Repository: neuronets/nobrainer Branch: master Commit: 3109460d048b Files: 220 Total size: 898.7 KB Directory structure: gitextract_7t995rdz/ ├── .autorc ├── .dockerignore ├── .flake8 ├── .gitattributes ├── .github/ │ ├── EC2_GPU_RUNNER.md │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ ├── documentation.md │ │ ├── feature_request.md │ │ ├── maintenance.md │ │ └── question.md │ ├── PULL_REQUEST_TEMPLATE.md │ └── workflows/ │ ├── ci.yml │ ├── guide-notebooks-ec2.yml │ ├── kwyk-reproduction-ec2.yml │ ├── publish.yml │ ├── release.yml │ └── validate-book.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .zenodo.json ├── CHANGELOG.md ├── CITATION ├── CLAUDE.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── conftest.py ├── docker/ │ ├── README.md │ ├── cpu.Dockerfile │ └── gpu.Dockerfile ├── nobrainer/ │ ├── __init__.py │ ├── _version.py │ ├── augmentation/ │ │ ├── __init__.py │ │ ├── profiles.py │ │ ├── synthseg.py │ │ └── transforms.py │ ├── cli/ │ │ ├── __init__.py │ │ ├── main.py │ │ └── tests/ │ │ ├── __init__.py │ │ └── main_test.py │ ├── dataset.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── openneuro.py │ │ └── zarr_store.py │ ├── distributed_learning/ │ │ └── dwc.py │ ├── experiment.py │ ├── gpu.py │ ├── io.py │ ├── layers/ │ │ ├── InstanceNorm.py │ │ ├── __init__.py │ │ ├── bernoulli_dropout.py │ │ ├── concrete_dropout.py │ │ ├── gaussian_dropout.py │ │ ├── maxpool4d.py │ │ ├── padding.py │ │ └── tests/ │ │ └── __init__.py │ ├── losses.py │ ├── metrics.py │ ├── models/ │ │ ├── __init__.py │ │ ├── _constants.py │ │ ├── _utils.py │ │ ├── autoencoder.py │ │ ├── bayesian/ │ │ │ ├── __init__.py │ │ │ ├── bayesian_meshnet.py │ │ │ ├── bayesian_vnet.py │ │ │ ├── kwyk_meshnet.py │ │ │ ├── layers.py │ │ │ ├── utils.py │ │ │ ├── vwn_layers.py │ │ │ └── warmstart.py │ │ ├── generative/ │ │ │ ├── __init__.py │ │ │ ├── dcgan.py │ │ │ └── progressivegan.py │ │ ├── highresnet.py │ │ ├── meshnet.py │ │ ├── segformer3d.py │ │ ├── segmentation.py │ │ ├── simsiam.py │ │ └── 
tests/ │ │ └── __init__.py │ ├── prediction.py │ ├── processing/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── croissant.py │ │ ├── dataset.py │ │ ├── generation.py │ │ └── segmentation.py │ ├── research/ │ │ ├── __init__.py │ │ ├── loop.py │ │ └── templates/ │ │ ├── .gitkeep │ │ ├── prepare.py │ │ └── train_bayesian_vnet.py │ ├── slurm.py │ ├── sr-tests/ │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_bayesian_uncertainty.py │ │ ├── test_brain_generation.py │ │ ├── test_croissant_metadata.py │ │ ├── test_dataset_builder.py │ │ ├── test_extract_patches.py │ │ ├── test_kwyk_smoke.py │ │ ├── test_raw_pytorch_api.py │ │ ├── test_segmentation_estimator.py │ │ ├── test_synthseg_brain.py │ │ ├── test_zarr_conversion.py │ │ └── test_zarr_pipeline.py │ ├── tests/ │ │ ├── __init__.py │ │ ├── contract/ │ │ │ ├── __init__.py │ │ │ └── test_cli.py │ │ ├── gpu/ │ │ │ ├── __init__.py │ │ │ ├── test_bayesian_e2e.py │ │ │ ├── test_gan_e2e.py │ │ │ ├── test_multi_gpu.py │ │ │ └── test_predict_e2e.py │ │ ├── integration/ │ │ │ ├── __init__.py │ │ │ ├── test_datalad_commit.py │ │ │ └── test_research_smoke.py │ │ └── unit/ │ │ ├── __init__.py │ │ ├── test_bayesian_layers.py │ │ ├── test_bayesian_models.py │ │ ├── test_class_weights.py │ │ ├── test_croissant.py │ │ ├── test_dataset.py │ │ ├── test_dataset_builder.py │ │ ├── test_datasets_openneuro.py │ │ ├── test_estimator_generation.py │ │ ├── test_estimator_segmentation.py │ │ ├── test_experiment.py │ │ ├── test_generative.py │ │ ├── test_gpu.py │ │ ├── test_io_weights.py │ │ ├── test_io_zarr.py │ │ ├── test_layers.py │ │ ├── test_losses.py │ │ ├── test_metrics.py │ │ ├── test_model_interface.py │ │ ├── test_model_registry.py │ │ ├── test_models_segmentation.py │ │ ├── test_prediction.py │ │ ├── test_research_commit.py │ │ ├── test_research_loop.py │ │ ├── test_segformer3d.py │ │ ├── test_slurm.py │ │ ├── test_stride_patches.py │ │ ├── test_synthseg.py │ │ ├── test_training.py │ │ ├── test_training_convergence.py │ │ ├── 
test_transform_pipeline.py │ │ ├── test_vwn_layers.py │ │ ├── test_zarr_dataset.py │ │ └── test_zarr_store.py │ ├── training.py │ ├── utils.py │ └── validation.py ├── pyproject.toml └── scripts/ ├── kwyk_reproduction/ │ ├── 01_assemble_dataset.py │ ├── 02_train_meshnet.py │ ├── 03_train_bayesian.py │ ├── 04_evaluate.py │ ├── 05_compare_kwyk.py │ ├── 06_block_size_sweep.py │ ├── ARCHITECTURE.md │ ├── README.md │ ├── __init__.py │ ├── build_kwyk_manifest.py │ ├── config.yaml │ ├── config_kwyk_smoke.yaml │ ├── convert_zarr_shard.py │ ├── experiments/ │ │ ├── 01_20260330_eval_deterministic/ │ │ │ ├── README.md │ │ │ ├── eval_deterministic.py │ │ │ ├── results_summary.md │ │ │ └── run.sbatch │ │ ├── 02_20260330_binary_bayesian/ │ │ │ ├── README.md │ │ │ ├── config.yaml │ │ │ ├── eval_binary.py │ │ │ ├── eval_only.sbatch │ │ │ └── run.sbatch │ │ ├── 03_20260330_warmstart_diagnostic/ │ │ │ ├── README.md │ │ │ ├── diagnose.py │ │ │ ├── results_summary.md │ │ │ └── run.sbatch │ │ ├── 04_20260330_fixed_warmstart/ │ │ │ ├── README.md │ │ │ ├── run.py │ │ │ └── run.sbatch │ │ ├── 05_20260330_kwyk_from_scratch/ │ │ │ ├── README.md │ │ │ ├── results_summary.md │ │ │ ├── run.py │ │ │ └── run.sbatch │ │ ├── 06_20260331_fullvol_augment/ │ │ │ ├── README.md │ │ │ ├── config_256.yaml │ │ │ ├── config_256_mp.yaml │ │ │ ├── config_fullvol.yaml │ │ │ ├── run_128.sbatch │ │ │ ├── run_256.sbatch │ │ │ ├── run_256_a100.sbatch │ │ │ ├── run_256_gradckpt.sbatch │ │ │ └── run_256_mp.sbatch │ │ ├── 07_20260401_ddp_128/ │ │ │ ├── config.yaml │ │ │ └── run.sbatch │ │ ├── 08_20260401_ddp_128_full/ │ │ │ ├── config.yaml │ │ │ └── run.sbatch │ │ └── task-planner.md │ ├── label_mappings/ │ │ ├── 115-class-mapping.csv │ │ ├── 50-class-mapping.csv │ │ └── 6-class-mapping.csv │ ├── run.sh │ ├── slurm_convert_zarr.sbatch │ ├── slurm_kwyk_bayesian.sbatch │ ├── slurm_kwyk_evaluate.sbatch │ ├── slurm_kwyk_smoke.sbatch │ ├── slurm_train.sbatch │ ├── slurm_zarr_array.sbatch │ ├── submit_kwyk_smoke.sh │ └── 
utils.py └── synthseg_evaluation/ ├── 02_train.py ├── 03_evaluate.py ├── 04_compare.py ├── README.md ├── config.yaml ├── run.sh └── slurm_train.sbatch ================================================ FILE CONTENTS ================================================ ================================================ FILE: .autorc ================================================ { "onlyPublishWithReleaseLabel": false, "baseBranch": "master", "prereleaseBranches": ["alpha"], "author": "Nobrainer Bot ", "noVersionPrefix": true, "plugins": ["git-tag"] } ================================================ FILE: .dockerignore ================================================ .git/ docker/ .idea/ ================================================ FILE: .flake8 ================================================ [flake8] max-line-length = 100 exclude = .git/ __pycache__/ build/ dist/ _version.py versioneer.py ignore = E203 W503 ================================================ FILE: .gitattributes ================================================ nobrainer/_version.py export-subst ================================================ FILE: .github/EC2_GPU_RUNNER.md ================================================ # EC2 GPU Runner Setup This document describes how to configure the AWS EC2 instance used as a self-hosted GitHub Actions runner for GPU integration tests. The workflow (`guide-notebooks-ec2.yml`) uses [machulav/ec2-github-runner](https://github.com/machulav/ec2-github-runner) to start an ephemeral EC2 instance, run GPU tests, and terminate the instance automatically. ## AMI preparation Start from the **AWS Deep Learning Base AMI (Amazon Linux 2023)** or any Amazon Linux 2023 AMI with NVIDIA drivers and CUDA pre-installed. The AMI must be in the same region as the `AWS_REGION` variable configured in GitHub. ### 1. 
Launch an instance to build the AMI ```bash aws ec2 run-instances \ --image-id ami-XXXXXXXX \ --instance-type g4dn.xlarge \ --key-name your-key-pair \ --security-group-ids sg-XXXXXXXX \ --subnet-id subnet-XXXXXXXX \ --block-device-mappings '[{"DeviceName":"/dev/xvda","Ebs":{"VolumeSize":100}}]' \ --tag-specifications 'ResourceType=instance,Tags=[{Key=Name,Value=nobrainer-ami-builder}]' ``` ### 2. SSH in as ec2-user and configure ```bash ssh -i your-key.pem ec2-user@INSTANCE_PUBLIC_DNS ``` Replace `INSTANCE_PUBLIC_DNS` with the builder instance's public DNS name or IP address. All commands below run as `ec2-user`. #### Install system dependencies ```bash sudo dnf install -y jq git ``` #### Install uv ```bash curl -LsSf https://astral.sh/uv/install.sh | sh source $HOME/.local/bin/env ``` #### Create the pre-installed nobrainer venv The GPU CI workflows expect a venv at `~/nobrainer-env` with heavy dependencies (torch, monai, pyro-ppl) already installed. This avoids re-downloading ~2 GB of packages on every CI run. ```bash uv venv --python 3.14 ~/nobrainer-env # Install the heavy GPU dependencies into the base venv. # The venv is not activated, so target it explicitly with --python # (without it, uv looks for an activated venv or ./.venv and fails). uv pip install --python ~/nobrainer-env/bin/python \ torch \ monai \ pyro-ppl \ pytorch-lightning \ pytest ``` #### Verify GPU access ```bash source ~/nobrainer-env/bin/activate python -c " import torch assert torch.cuda.is_available(), 'CUDA not available' print(f'GPU: {torch.cuda.get_device_name(0)}') print(f'CUDA: {torch.version.cuda}') print(f'PyTorch: {torch.__version__}') " deactivate ``` ### 3. Create the AMI Optionally stop the instance first for a cleaner filesystem snapshot. The command below passes `--no-reboot`, so it also works on a running instance: ```bash aws ec2 create-image \ --instance-id i-XXXXXXXXX \ --name "nobrainer-pytorch-gpu-$(date +%Y%m%d)" \ --description "Amazon Linux 2023 + CUDA + PyTorch + uv for nobrainer GPU CI" \ --no-reboot ``` Note the resulting AMI ID — this goes into the `AWS_IMAGE_ID` variable. ### 4.
Terminate the builder instance ```bash aws ec2 terminate-instances --instance-id i-XXXXXXXXX ``` ## GitHub configuration ### Secrets (Settings → Secrets → Actions) | Name | Description | |------|-------------| | `AWS_KEY_ID` | IAM access key with the EC2 permissions listed under *IAM policy* below | | `AWS_KEY_SECRET` | Corresponding secret access key | | `GH_TOKEN` | GitHub PAT with `repo` scope (used by machulav/ec2-github-runner to register the runner) | ### Variables (Settings → Variables → Actions) | Name | Example | Description | |------|---------|-------------| | `AWS_REGION` | `us-east-1` | Region where the AMI lives | | `AWS_IMAGE_ID` | `ami-0abc123def456` | The AMI created above | | `AWS_INSTANCE_TYPE` | `g4dn.xlarge` | 1x T4 GPU (~$0.53/hr); `p3.2xlarge` for V100 | | `AWS_SUBNET` | `subnet-0abc123` | Must have internet access for runner registration | | `AWS_SECURITY_GROUP` | `sg-0abc123` | Allow outbound HTTPS (port 443) | ## IAM policy (minimum permissions) ```json { "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", "Action": [ "ec2:RunInstances", "ec2:TerminateInstances", "ec2:DescribeInstances", "ec2:DescribeInstanceStatus", "ec2:CreateTags" ], "Resource": "*" }, { "Effect": "Allow", "Action": "iam:PassRole", "Resource": "*" } ] } ``` ## Updating the base venv When upgrading PyTorch or other dependencies, SSH into a running instance (or launch the AMI), update `~/nobrainer-env`, and create a new AMI snapshot: ```bash ssh -i your-key.pem ec2-user@INSTANCE_PUBLIC_DNS # On the instance: upgrade packages inside the base venv (it is not activated, so target it explicitly) uv pip install --python ~/nobrainer-env/bin/python --upgrade torch monai pyro-ppl pytorch-lightning # Then create a new AMI and update AWS_IMAGE_ID in GitHub variables ``` ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Report a bug (e.g., something not working as described, missing/incorrect documentation).
title: '' labels: 'bug' assignees: '' --- ================================================ FILE: .github/ISSUE_TEMPLATE/documentation.md ================================================ --- name: Documentation improvement about: Request improvements to the documentation and tutorials. title: '' labels: 'documentation' assignees: '' --- ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Propose a new feature or a change to an existing feature. title: '' labels: 'feature' assignees: '' --- ================================================ FILE: .github/ISSUE_TEMPLATE/maintenance.md ================================================ --- name: Maintenance and delivery about: Suggestions and requests regarding the infrastructure for development, testing, and delivery. title: '' labels: 'maintenance' assignees: '' --- ================================================ FILE: .github/ISSUE_TEMPLATE/question.md ================================================ --- name: Question about: Not sure if you are using Nobrainer correctly, or other questions? This is the place. title: '' labels: 'question' assignees: '' --- ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ ## Types of changes - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Summary ## Checklist - [ ] I have added tests to cover my changes - [ ] I have updated documentation (if necessary) ## Acknowledgment - [ ] I acknowledge that this contribution will be available under the Apache 2 license. 
================================================
FILE: .github/workflows/ci.yml
================================================
# CPU-only CI: runs unit tests, sr-tests and a research-loop smoke test on a
# matrix of Python versions, and checks that the CPU Docker image builds.
name: CI

on:
  push:
    branches: [main, master]
  pull_request:
    branches: [main, master, alpha]

jobs:
  unit-tests:
    runs-on: ${{ matrix.os }}
    strategy:
      # Let every Python version finish even if one of them fails.
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.12", "3.13", "3.14"]
    steps:
      # Full history + tags; presumably needed so the package version can be
      # derived from git tags (see _version.py) — TODO confirm.
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          fetch-tags: true

      - name: Install uv
        uses: astral-sh/setup-uv@v4

      # Reuse downloaded sample brain volumes across runs (fixed key: bump
      # the -v1 suffix to invalidate).
      - name: Cache sample brain data
        uses: actions/cache@v4
        with:
          path: /tmp/nobrainer-data
          key: nobrainer-sample-data-v1

      # Creates ./.venv, which the following `uv pip` / `uv run` calls use.
      - name: Set up Python ${{ matrix.python-version }}
        run: uv venv --python ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          uv pip install \
            ".[bayesian,generative,zarr,dev]" \
            monai \
            pyro-ppl

      # Tests marked `gpu` are deselected; GPU coverage lives in the EC2 workflows.
      - name: Test with pytest (CPU, skip GPU)
        run: |
          uv run pytest nobrainer/tests/unit/ -v \
            -m "not gpu" \
            --no-header

      - name: Run sr-tests (somewhat realistic tests)
        run: |
          uv run pytest nobrainer/sr-tests/ -v \
            -m "not gpu" \
            --no-header \
            --tb=short

      # Best-effort: `|| true` keeps a failing research run from failing CI,
      # and the final check only prints a warning when the summary is missing.
      - name: Research loop smoke test (5 min budget, no API key)
        run: |
          mkdir -p /tmp/research-smoke
          cp nobrainer/research/templates/train_bayesian_vnet.py /tmp/research-smoke/train.py
          cp nobrainer/research/templates/prepare.py /tmp/research-smoke/prepare.py
          uv run nobrainer research run \
            --working-dir /tmp/research-smoke \
            --model-family meshnet \
            --max-experiments 2 \
            --budget-minutes 5 || true
          test -f /tmp/research-smoke/run_summary.md && echo "run_summary.md exists" || echo "WARN: no run_summary.md"

  # Only verifies that the CPU image builds; nothing is pushed.
  image-build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          fetch-tags: true

      - name: Test CPU Docker image build
        run: |
          docker build -t neuronets/nobrainer:ci-cpu -f docker/cpu.Dockerfile .
================================================
FILE: .github/workflows/guide-notebooks-ec2.yml
================================================
name: GPU Tests - EC2
run-name: ${{ github.ref_name }} - GPU Tests - EC2

on:
  push:
    branches: [main, master]
  # PRs require the 'gpu-test-approved' label (added by a repo admin) before
  # this workflow starts a self-hosted GPU runner.
  # NOTE: 'synchronize' is also a trigger. Once the label is present, every
  # later push to the PR runs again WITHOUT re-approval. Remove the label
  # before new, unreviewed commits land.
  pull_request:
    branches: [main, master, alpha]
    types: [labeled, synchronize]

jobs:
  start-runner:
    name: Start self-hosted EC2 runner
    runs-on: ubuntu-latest
    # Only run on PRs if an admin added the 'gpu-test-approved' label
    if: >-
      github.event_name == 'push' ||
      (github.event_name == 'pull_request' &&
       contains(github.event.pull_request.labels.*.name, 'gpu-test-approved'))
    outputs:
      label: ${{ steps.start-ec2-runner.outputs.label }}
      ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
      multi_gpu: ${{ steps.gpu-config.outputs.multi_gpu }}
    steps:
      - name: Parse GPU config labels
        id: gpu-config
        if: github.event_name == 'pull_request'
        env:
          # Passed through the environment instead of being inlined into the
          # script, so label text can never break shell quoting.
          LABELS: ${{ toJSON(github.event.pull_request.labels.*.name) }}
        run: |
          # Priority: gpu-instance:TYPE (exact) > gpu-multi (multi-GPU) > gpu-family:FAMILY (family default)
          INSTANCE=$(echo "$LABELS" | jq -r '[.[] | select(startswith("gpu-instance:"))] | if length > 0 then .[0] | split(":")[1] else "" end')

          # gpu-multi label → pick a multi-GPU instance for DDP/model-parallel tests
          MULTI_GPU=$(echo "$LABELS" | jq -r '[.[] | select(. == "gpu-multi")] | if length > 0 then "true" else "" end')
          if [ -z "$INSTANCE" ] && [ -n "$MULTI_GPU" ]; then
            INSTANCE="g5.12xlarge"  # 4x A10G GPUs
            echo "gpu-multi label → selecting $INSTANCE (4 GPUs)"
          fi

          # gpu-family: label → pick default instance from that family
          # Supported families: g4dn, g5, g6, p3, p4d, p5
          if [ -z "$INSTANCE" ]; then
            FAMILY=$(echo "$LABELS" | jq -r '[.[] | select(startswith("gpu-family:"))] | if length > 0 then .[0] | split(":")[1] else "" end')
            if [ -n "$FAMILY" ]; then
              case "$FAMILY" in
                g4dn) INSTANCE="g4dn.xlarge" ;;    # 1x T4
                g5)   INSTANCE="g5.xlarge" ;;      # 1x A10G
                g6)   INSTANCE="g6.xlarge" ;;      # 1x L4
                p3)   INSTANCE="p3.2xlarge" ;;     # 1x V100
                p4d)  INSTANCE="p4d.24xlarge" ;;   # 8x A100
                p5)   INSTANCE="p5.48xlarge" ;;    # 8x H100
                *)    echo "Unknown GPU family: $FAMILY"; INSTANCE="" ;;
              esac
              if [ -n "$INSTANCE" ]; then
                echo "gpu-family:${FAMILY} → selecting $INSTANCE"
              fi
            fi
          fi

          # Default to spot pricing; gpu-ondemand:true overrides to on-demand.
          # Emit an explicit "on-demand" marker. An empty output would be
          # indistinguishable from "step skipped" and would be replaced by
          # the spot fallback in the Start EC2 runner step.
          ONDEMAND=$(echo "$LABELS" | jq -r '[.[] | select(. == "gpu-ondemand:true")] | if length > 0 then "true" else "" end')
          if [ -n "$ONDEMAND" ]; then
            MARKET="on-demand"
          else
            MARKET="spot"
          fi

          echo "instance=${INSTANCE}" >> "$GITHUB_OUTPUT"
          echo "market_type=${MARKET}" >> "$GITHUB_OUTPUT"
          echo "multi_gpu=${MULTI_GPU}" >> "$GITHUB_OUTPUT"
          echo "Parsed labels: instance=${INSTANCE:-default}, market=${MARKET}, multi_gpu=${MULTI_GPU:-false}"

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v6
        with:
          aws-access-key-id: ${{ secrets.AWS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_KEY_SECRET }}
          aws-region: ${{ vars.AWS_REGION }}

      - name: Start EC2 runner
        id: start-ec2-runner
        uses: machulav/ec2-github-runner@v2.5.2
        with:
          mode: start
          github-token: ${{ secrets.GH_TOKEN }}
          ec2-image-id: ${{ vars.AWS_IMAGE_ID }}
          ec2-instance-type: ${{ steps.gpu-config.outputs.instance || vars.AWS_INSTANCE_TYPE }}
          subnet-id: ${{ vars.AWS_SUBNET }}
          security-group-id: ${{ vars.AWS_SECURITY_GROUP }}
          # 'spot' requests spot capacity; an empty market-type launches on-demand.
          # The old `market_type || 'spot'` turned the on-demand value ('') back
          # into 'spot'. Map the explicit "on-demand" marker to '' instead.
          # Anything else, including push events where gpu-config is skipped,
          # stays on spot.
          market-type: ${{ steps.gpu-config.outputs.market_type != 'on-demand' && 'spot' || '' }}

  gpu-tests:
    needs: start-runner
    runs-on: ${{ needs.start-runner.outputs.label }}
    env:
      # The GitHub Actions runner runs as root on the EC2 instance, but
      # the AMI was set up as ec2-user. Use absolute paths to ec2-user's
      # home directory for the pre-installed venv and uv binary.
      EC2_USER_HOME: /home/ec2-user
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      # -f makes curl fail on HTTP errors (e.g. IMDSv2-only instances) so the
      # 'unknown' fallback is printed instead of an error body.
      - name: Log GPU runner config
        run: |
          echo "Instance type: $(curl -sf http://169.254.169.254/latest/meta-data/instance-type 2>/dev/null || echo 'unknown')"
          echo "Availability zone: $(curl -sf http://169.254.169.254/latest/meta-data/placement/availability-zone 2>/dev/null || echo 'unknown')"
          echo "Market type: $(curl -sf http://169.254.169.254/latest/meta-data/instance-life-cycle 2>/dev/null || echo 'unknown')"

      - name: Set up venv from pre-installed base
        run: |
          set -ex
          BASE_VENV="${EC2_USER_HOME}/nobrainer-env"
          export PATH="${EC2_USER_HOME}/.local/bin:$PATH"
          if [ -d "$BASE_VENV" ]; then
            echo "Found pre-installed base venv at $BASE_VENV"
            # Copy the base venv so the AMI stays clean for next run
            cp -a "$BASE_VENV" .venv
          else
            echo "No base venv found — creating from scratch"
            uv venv --python 3.14
          fi
          # Install nobrainer from checkout on top of the base layer
          uv pip install \
            ".[bayesian,generative,zarr,dev]" \
            monai \
            pyro-ppl \
            matplotlib

      - name: Verify GPU access
        run: |
          export PATH="${EC2_USER_HOME}/.local/bin:$PATH"
          uv run python -c "
          import torch
          assert torch.cuda.is_available(), 'CUDA not available'
          n = torch.cuda.device_count()
          print(f'GPUs: {n}')
          for i in range(n):
              print(f' [{i}] {torch.cuda.get_device_name(i)}')
          print(f'CUDA: {torch.version.cuda}')
          print(f'PyTorch: {torch.__version__}')
          "

      - name: Run full test suite (including GPU)
        run: |
          export PATH="${EC2_USER_HOME}/.local/bin:$PATH"
          uv run pytest nobrainer/tests/ nobrainer/sr-tests/ -v \
            --no-header \
            --tb=short

      - name: Run multi-GPU tests (DDP + model parallel)
        if: needs.start-runner.outputs.multi_gpu == 'true'
        run: |
          export PATH="${EC2_USER_HOME}/.local/bin:$PATH"
          echo "=== Multi-GPU DDP and model-parallel tests ==="
          uv run pytest nobrainer/tests/gpu/ -v \
            --no-header \
            --tb=short \
            -k "multi_gpu or ddp or model_parallel"

  stop-runner:
    name: Stop self-hosted EC2 runner
    needs:
      - start-runner
      - gpu-tests
    runs-on: ubuntu-latest
    # Always tear the instance down once it was started, even if tests failed.
    if: ${{ always() && needs.start-runner.result == 'success' }}
    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v6
        with:
          aws-access-key-id: ${{ secrets.AWS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_KEY_SECRET }}
          aws-region: ${{ vars.AWS_REGION }}

      - name: Stop EC2 runner
        uses: machulav/ec2-github-runner@v2.5.2
        with:
          mode: stop
          github-token: ${{ secrets.GH_TOKEN }}
          label: ${{ needs.start-runner.outputs.label }}
          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}

================================================
FILE: .github/workflows/kwyk-reproduction-ec2.yml
================================================
name: KWYK Reproduction - EC2 GPU
run-name: ${{ github.ref_name }} - KWYK Reproduction

on:
  workflow_dispatch:
    inputs:
      mode:
        description: "Run mode"
        required: true
        default: "smoke-test"
        type: choice
        options:
          - smoke-test
          - small-train
          - full
      instance_type:
        description: "EC2 instance type"
        required: false
        default: ""
        type: string
      on_demand:
        description: "Use on-demand pricing (not spot)"
        required: false
        default: false
        type: boolean
  pull_request:
    branches: [main, master, alpha]
    types: [labeled, synchronize]

jobs:
  start-runner:
    name: Start self-hosted EC2 runner
    runs-on: ubuntu-latest
    if: >-
      github.event_name == 'workflow_dispatch' ||
      (github.event_name == 'pull_request' &&
       contains(github.event.pull_request.labels.*.name, 'kwyk-gpu-test'))
    outputs:
      label: ${{ steps.start-ec2-runner.outputs.label }}
      ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
    steps:
      - name: Determine instance config
        id: gpu-config
        env:
          # Free-form inputs/labels go through the environment rather than
          # being inlined into the script (avoids shell/script injection).
          EVENT_NAME: ${{ github.event_name }}
          DISPATCH_INSTANCE_TYPE: ${{ inputs.instance_type }}
          DISPATCH_ON_DEMAND: ${{ inputs.on_demand }}
          LABELS: ${{ toJSON(github.event.pull_request.labels.*.name) }}
        run: |
          if [ "$EVENT_NAME" = "workflow_dispatch" ]; then
            INSTANCE="$DISPATCH_INSTANCE_TYPE"
            if [ "$DISPATCH_ON_DEMAND" = "true" ]; then
              MARKET="on-demand"
            else
              MARKET="spot"
            fi
          else
            # PR: parse labels
            INSTANCE=$(echo "$LABELS" | jq -r '[.[] | select(startswith("gpu-instance:"))] | if length > 0 then .[0] | split(":")[1] else "" end')
            ONDEMAND=$(echo "$LABELS" | jq -r '[.[] | select(. == "gpu-ondemand:true")] | if length > 0 then "true" else "" end')
            if [ -n "$ONDEMAND" ]; then MARKET="on-demand"; else MARKET="spot"; fi
          fi
          # MARKET is always "spot" or the explicit "on-demand" marker; an
          # empty value would be rewritten to spot by the start step below.
          echo "instance=${INSTANCE}" >> "$GITHUB_OUTPUT"
          echo "market_type=${MARKET}" >> "$GITHUB_OUTPUT"
          echo "Config: instance=${INSTANCE:-default}, market=${MARKET}"

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v6
        with:
          aws-access-key-id: ${{ secrets.AWS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_KEY_SECRET }}
          aws-region: ${{ vars.AWS_REGION }}

      - name: Start EC2 runner
        id: start-ec2-runner
        uses: machulav/ec2-github-runner@v2.5.2
        with:
          mode: start
          github-token: ${{ secrets.GH_TOKEN }}
          ec2-image-id: ${{ vars.AWS_IMAGE_ID }}
          ec2-instance-type: ${{ steps.gpu-config.outputs.instance || vars.AWS_INSTANCE_TYPE }}
          subnet-id: ${{ vars.AWS_SUBNET }}
          security-group-id: ${{ vars.AWS_SECURITY_GROUP }}
          # 'spot' requests spot capacity; an empty market-type launches on-demand.
          # Map the explicit "on-demand" marker to '' (the old `|| 'spot'`
          # fallback silently turned on-demand requests back into spot).
          market-type: ${{ steps.gpu-config.outputs.market_type != 'on-demand' && 'spot' || '' }}

  kwyk-reproduction:
    needs: start-runner
    runs-on: ${{ needs.start-runner.outputs.label }}
    # The job-level limit caps every step, so a fixed 120 made the full-mode
    # step's 1440-minute timeout unreachable. Full runs get 1500 minutes
    # (24 h step + setup); other modes, including PR runs where `inputs` is
    # empty, keep the 2-hour cap.
    timeout-minutes: ${{ inputs.mode == 'full' && 1500 || 120 }}
    env:
      EC2_USER_HOME: /home/ec2-user
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          fetch-tags: true

      # -f makes curl fail on HTTP errors so the 'unknown' fallback is used.
      - name: Log GPU runner config
        run: |
          echo "Instance type: $(curl -sf http://169.254.169.254/latest/meta-data/instance-type 2>/dev/null || echo 'unknown')"
          echo "Market type: $(curl -sf http://169.254.169.254/latest/meta-data/instance-life-cycle 2>/dev/null || echo 'unknown')"

      - name: Install git-annex
        run: |
          # Runner runs as root but EC2_USER_HOME points to ec2-user.
          # Add both possible bin dirs to PATH.
          export PATH="/root/.local/bin:${EC2_USER_HOME}/.local/bin:$PATH"
          if ! command -v git-annex &>/dev/null; then
            echo "Installing git-annex via uv..."
            uv tool install git-annex
          fi
          git-annex version
          # Persist PATH for subsequent steps
          echo "/root/.local/bin" >> "$GITHUB_PATH"

      - name: Set up venv from pre-installed base
        run: |
          set -ex
          BASE_VENV="${EC2_USER_HOME}/nobrainer-env"
          export PATH="${EC2_USER_HOME}/.local/bin:$PATH"
          if [ -d "$BASE_VENV" ]; then
            echo "Found pre-installed base venv at $BASE_VENV"
            cp -a "$BASE_VENV" .venv
          else
            echo "No base venv found — creating from scratch"
            uv venv --python 3.14
          fi
          uv pip install \
            ".[bayesian,generative,zarr,versioning,dev]" \
            monai pyro-ppl datalad matplotlib pyyaml scipy nibabel

      - name: Verify GPU access
        run: |
          export PATH="${EC2_USER_HOME}/.local/bin:$PATH"
          uv run python -c "
          import torch
          assert torch.cuda.is_available(), 'CUDA not available'
          print(f'GPU: {torch.cuda.get_device_name(0)}')
          print(f'CUDA: {torch.version.cuda}')
          print(f'PyTorch: {torch.__version__}')
          print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
          "

      - name: Cache sample brain data
        uses: actions/cache@v4
        with:
          path: /tmp/nobrainer-data
          key: nobrainer-sample-data-v1

      - name: Run kwyk sr-tests smoke test
        run: |
          export PATH="${EC2_USER_HOME}/.local/bin:$PATH"
          uv run pytest nobrainer/sr-tests/test_kwyk_smoke.py -v \
            --no-header \
            --tb=short

      - name: Determine run mode
        id: mode
        run: |
          if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
            echo "mode=${{ inputs.mode }}" >> "$GITHUB_OUTPUT"
          else
            echo "mode=smoke-test" >> "$GITHUB_OUTPUT"
          fi

      - name: Run kwyk reproduction scripts (smoke test)
        if: steps.mode.outputs.mode == 'smoke-test'
        run: |
          set -ex
          export PATH="${EC2_USER_HOME}/.local/bin:$PATH"
          cd scripts/kwyk_reproduction
          # Try DataLad assembly first (git-annex was installed above). If it
          # fails, fall back to the bundled sample brain data. The sr-tests
          # already validated the pipeline with get_data(); here we verify
          # that the reproduction scripts parse configs correctly and that
          # the training loop works end-to-end.
          uv run python 01_assemble_dataset.py \
            --datasets ds000114 \
            --output-csv manifest.csv \
            --output-dir data \
            --label-mapping binary \
            || {
            echo "DataLad assembly failed, falling back to sample brain data"
            uv run python -c "
          import csv; from nobrainer.utils import get_data
          src = get_data(); pairs = []
          with open(src) as f:
              r = csv.reader(f); next(r)
              pairs = list(r)[:5]
          splits = ['train','train','train','val','test']
          with open('manifest.csv', 'w', newline='') as f:
              w = csv.DictWriter(f, ['t1w_path','label_path','split']); w.writeheader()
              for i,(t1,lbl) in enumerate(pairs):
                  w.writerow(dict(t1w_path=t1, label_path=lbl, split=splits[i]))
          print('Manifest created with', len(pairs), 'volumes')
          "
          }

          echo "=== Step 2: Train deterministic MeshNet (2 epochs) ==="
          uv run python 02_train_meshnet.py \
            --manifest manifest.csv \
            --config config.yaml \
            --output-dir checkpoints/meshnet \
            --epochs 2

          echo "=== Step 3a: MC dropout variant ==="
          uv run python 03_train_bayesian.py \
            --manifest manifest.csv \
            --config config.yaml \
            --variant bwn_multi \
            --warmstart checkpoints/meshnet \
            --output-dir checkpoints/bwn_multi \
            --epochs 2

          echo "=== Step 3b: Spike-and-slab variant (2 epochs) ==="
          uv run python 03_train_bayesian.py \
            --manifest manifest.csv \
            --config config.yaml \
            --variant bvwn_multi_prior \
            --warmstart checkpoints/meshnet \
            --output-dir checkpoints/bvwn_multi_prior \
            --epochs 2

          echo "=== Checking outputs ==="
          ls -la checkpoints/meshnet/ checkpoints/bwn_multi/ checkpoints/bvwn_multi_prior/ 2>/dev/null || true

      - name: Run kwyk reproduction scripts (small training)
        if: steps.mode.outputs.mode == 'small-train'
        run: |
          set -ex
          export PATH="${EC2_USER_HOME}/.local/bin:$PATH"
          cd scripts/kwyk_reproduction
          # Try DataLad first; fall back to sample brain data if it fails
          uv run python 01_assemble_dataset.py \
            --datasets ds000114 \
            --output-csv manifest.csv \
            --output-dir data \
            --label-mapping binary \
            || {
            echo "DataLad assembly failed, falling back to sample brain data"
            uv run python -c "
          import csv; from nobrainer.utils import get_data
          src = get_data(); pairs = []
          with open(src) as f:
              r = csv.reader(f); next(r)
              pairs = list(r)[:5]
          splits = ['train','train','train','val','test']
          with open('manifest.csv', 'w', newline='') as f:
              w = csv.DictWriter(f, ['t1w_path','label_path','split']); w.writeheader()
              for i,(t1,lbl) in enumerate(pairs):
                  w.writerow(dict(t1w_path=t1, label_path=lbl, split=splits[i]))
          print('Manifest created with', len(pairs), 'volumes')
          "
          }

          echo "=== Step 2: Train deterministic MeshNet (20 epochs) ==="
          uv run python 02_train_meshnet.py \
            --manifest manifest.csv \
            --config config.yaml \
            --output-dir checkpoints/meshnet \
            --epochs 20

          echo "=== Step 3a: MC dropout variant ==="
          uv run python 03_train_bayesian.py \
            --manifest manifest.csv \
            --config config.yaml \
            --variant bwn_multi \
            --warmstart checkpoints/meshnet \
            --output-dir checkpoints/bwn_multi \
            --epochs 20

          echo "=== Step 3b: Spike-and-slab variant (20 epochs) ==="
          uv run python 03_train_bayesian.py \
            --manifest manifest.csv \
            --config config.yaml \
            --variant bvwn_multi_prior \
            --warmstart checkpoints/meshnet \
            --output-dir checkpoints/bvwn_multi_prior \
            --epochs 20

          echo "=== Step 3c: Gaussian Bayesian variant (20 epochs) ==="
          uv run python 03_train_bayesian.py \
            --manifest manifest.csv \
            --config config.yaml \
            --variant bayesian_gaussian \
            --warmstart checkpoints/meshnet \
            --output-dir checkpoints/bayesian_gaussian \
            --epochs 20

          echo "=== Checking outputs ==="
          ls -la checkpoints/*/ 2>/dev/null || true

      - name: Run kwyk reproduction scripts (full)
        if: steps.mode.outputs.mode == 'full'
        # Only effective because the job-level timeout above is raised for full mode.
        timeout-minutes: 1440
        run: |
          set -ex
          export PATH="${EC2_USER_HOME}/.local/bin:$PATH"
          cd scripts/kwyk_reproduction

          uv run python 01_assemble_dataset.py \
            --datasets ds000114 ds000228 ds002609 ds001021 ds002105 \
            --output-csv manifest.csv \
            --output-dir data \
            --label-mapping binary --conform

          uv run python 02_train_meshnet.py \
            --manifest manifest.csv \
            --config config.yaml \
            --output-dir checkpoints/meshnet \
            --epochs 50

          for variant in bwn_multi bvwn_multi_prior bayesian_gaussian; do
            echo "=== Training $variant (50 epochs) ==="
            uv run python 03_train_bayesian.py \
              --manifest manifest.csv \
              --config config.yaml \
              --variant "$variant" \
              --warmstart checkpoints/meshnet \
              --output-dir "checkpoints/$variant" \
              --epochs 50
          done

          for variant in meshnet bwn_multi bvwn_multi_prior bayesian_gaussian; do
            echo "=== Evaluating $variant ==="
            uv run python 04_evaluate.py \
              --model "checkpoints/$variant/model.pth" \
              --manifest manifest.csv \
              --split test \
              --n-samples 10 \
              --output-dir "results/$variant"
          done

          uv run python 05_compare_kwyk.py \
            --new-model checkpoints/bvwn_multi_prior/model.pth \
            --kwyk-dir ../../kwyk \
            --manifest manifest.csv \
            --output-dir results/comparison || echo "WARN: kwyk comparison failed (container may not be available)"

          uv run python 06_block_size_sweep.py \
            --manifest manifest.csv \
            --block-sizes 32 64 128 \
            --output-dir results/sweep

      - name: Upload artifacts
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: kwyk-reproduction-${{ steps.mode.outputs.mode }}
          path: |
            scripts/kwyk_reproduction/figures/
            scripts/kwyk_reproduction/results/
            scripts/kwyk_reproduction/checkpoints/*/croissant.json
          retention-days: 30
          if-no-files-found: warn

  stop-runner:
    name: Stop self-hosted EC2 runner
    needs:
      - start-runner
      - kwyk-reproduction
    runs-on: ubuntu-latest
    # Always tear the instance down once it was started, even if the job failed.
    if: ${{ always() && needs.start-runner.result == 'success' }}
    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v6
        with:
          aws-access-key-id: ${{ secrets.AWS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_KEY_SECRET }}
          aws-region: ${{ vars.AWS_REGION }}

      - name: Stop EC2 runner
        uses: machulav/ec2-github-runner@v2.5.2
        with:
          mode: stop
          github-token: ${{ secrets.GH_TOKEN }}
          label: ${{ needs.start-runner.outputs.label }}
          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
================================================ FILE: .github/workflows/publish.yml ================================================ name: Publish to PyPI on GitHub release on: release: types: [published] jobs: pypi-release: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Install uv uses: astral-sh/setup-uv@v4 - name: Build and publish run: | uv build uv publish env: UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }} ================================================ FILE: .github/workflows/release.yml ================================================ name: Auto-release on PR merge on: pull_request: branches: [master, alpha] types: [closed] env: AUTO_VERSION: v11.2.1 jobs: auto-release: name: Create release runs-on: ubuntu-latest # Stable release: merged PR to master with 'release' label # Alpha pre-release: merged PR to alpha (book tutorials are validated # separately by manually dispatching validate-book.yml; it does not # run automatically on PR pushes) if: >- github.event.pull_request.merged == true && ( contains(github.event.pull_request.labels.*.name, 'release') || github.event.pull_request.base.ref == 'alpha' ) steps: - uses: actions/checkout@v4 with: fetch-depth: 0 fetch-tags: true - name: Unset header run: git config --local --unset http.https://github.com/.extraheader - name: Download auto run: | auto_download_url="$(curl -fsSL https://api.github.com/repos/intuit/auto/releases/tags/$AUTO_VERSION | jq -r '.assets[] | select(.name == "auto-linux.gz") | .browser_download_url')" wget -O- "$auto_download_url" | gunzip > ~/auto chmod a+x ~/auto - name: Create release run: ~/auto shipit -vv env: GH_TOKEN: ${{ secrets.AUTO_USER_TOKEN }} ================================================ FILE: .github/workflows/validate-book.yml ================================================ name: Validate nobrainer-book tutorials on: workflow_dispatch: # Manual trigger only — not part of CI checks jobs: validate-book: name: Run nobrainer-book tutorials runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name:
Install uv uses: astral-sh/setup-uv@v4 - name: Set up Python run: uv venv --python 3.14 - name: Install nobrainer and tutorial deps from PR branch run: | uv pip install \ ".[bayesian,generative,zarr,dev]" \ monai \ pyro-ppl \ matplotlib \ nilearn - name: Clone nobrainer-book (matching branch or alpha) run: | PR_BRANCH="${{ github.head_ref }}" BOOK_REPO="https://github.com/neuronets/nobrainer-book.git" # Try a book branch with the same name as the PR's head branch first # (for lockstep development), otherwise fall back to alpha. Note that # github.head_ref is only set for pull_request events; under the # manual workflow_dispatch trigger it is empty, so alpha is used. if git ls-remote --heads "$BOOK_REPO" "$PR_BRANCH" | grep -q .; then echo "Using matching book branch: $PR_BRANCH" git clone --branch "$PR_BRANCH" --depth 1 "$BOOK_REPO" /tmp/nobrainer-book else echo "No matching branch '$PR_BRANCH' on nobrainer-book, using alpha" git clone --branch alpha --depth 1 "$BOOK_REPO" /tmp/nobrainer-book fi - name: Run book tutorials run: | for script in /tmp/nobrainer-book/docs/nobrainer-guides/scripts/0*.py /tmp/nobrainer-book/docs/nobrainer-guides/scripts/1[01]*.py; do echo "=== Running $(basename $script) ===" uv run python "$script" || { echo "FAILED: $script" exit 1 } done echo "All book tutorials passed" ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # pycharm .idea/ # guide data guide/data/ # Model artifacts *.pth brain_mask_extraction_model/ data/ # Model artifacts *.pth brain_mask_extraction_model/ ================================================ FILE: .pre-commit-config.yaml ================================================ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks ci: skip: [codespell] repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - id: check-added-large-files - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black rev: 24.3.0 hooks: - id: black - repo: https://github.com/PyCQA/flake8 rev: 7.0.0 hooks: - id: flake8 - repo: https://github.com/PyCQA/isort rev: 5.13.2 hooks: - id: isort exclude: ^(nobrainer/_version\.py|versioneer\.py)$ - repo: https://github.com/codespell-project/codespell rev: v2.2.6 hooks: - id: codespell exclude: ^(nobrainer/_version\.py|versioneer\.py|pyproject\.toml|CHANGELOG\.md)$ ================================================ FILE: .zenodo.json ================================================ { "creators": [ { "affiliation": "Stony Brook University", "name": "Kaczmarzyk, Jakub", "orcid": "0000-0002-5544-7577" }, { "affiliation": "NIMH", "name": "McClure, Patrick" }, { "affiliation": "MIT", "name": "Zulfikar, Wazeer" }, { "affiliation": "MIT", "name": "Rana, Aakanksha", "orcid": "0000-0002-8350-7602" }, { "affiliation": "MIT", "name": "Rajaei, Hoda", "orcid": "0000-0002-0754-5586" }, { 
"affiliation": "University of Washington", "name": "Richie-Halford, Adam", "orcid": "0000-0001-9276-9084" }, { "affiliation": "Department of Psychology, Stanford University", "name": "Bansal, Shashank", "orcid": "0000-0002-1252-8772" }, { "affiliation": "MIT", "name": "Jarecka, Dorota", "orcid": "0000-0001-8282-2988" }, { "affiliation": "NIMH", "name": "Lee, John" }, { "affiliation": "MIT, HMS", "name": "Ghosh, Satrajit", "orcid": "0000-0002-5312-6729" } ], "keywords": [ "neuroimaging", "deep learning", "bayesian neural network" ], "license": "Apache-2.0", "upload_type": "software" } ================================================ FILE: CHANGELOG.md ================================================ # 1.2.1 (Thu Apr 04 2024) #### 🐛 Bug Fix - Fix PGAN notebook [#319](https://github.com/neuronets/nobrainer/pull/319) ([@satra](https://github.com/satra) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - fix dependencies [#318](https://github.com/neuronets/nobrainer/pull/318) ([@satra](https://github.com/satra)) - Update setup.cfg to add cuda option [#309](https://github.com/neuronets/nobrainer/pull/309) ([@satra](https://github.com/satra)) - [pre-commit.ci] pre-commit autoupdate [#294](https://github.com/neuronets/nobrainer/pull/294) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra) [@hvgazula](https://github.com/hvgazula)) #### Authors: 3 - [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) - H Gazula ([@hvgazula](https://github.com/hvgazula)) - Satrajit Ghosh ([@satra](https://github.com/satra)) --- # 1.2.0 (Fri Mar 22 2024) #### 🚀 Enhancement - Dev [#295](https://github.com/neuronets/nobrainer/pull/295) ([@hvgazula](https://github.com/hvgazula) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - Update setup.cfg [#299](https://github.com/neuronets/nobrainer/pull/299) ([@satra](https://github.com/satra)) #### 🐛 Bug Fix - Update release.yml 
([@satra](https://github.com/satra)) - change hyphenation [#275](https://github.com/neuronets/nobrainer/pull/275) ([@satra](https://github.com/satra)) - update precommit checks [#275](https://github.com/neuronets/nobrainer/pull/275) ([@satra](https://github.com/satra)) - fix docker syntax [#275](https://github.com/neuronets/nobrainer/pull/275) ([@satra](https://github.com/satra)) - remove trained models [#275](https://github.com/neuronets/nobrainer/pull/275) ([@satra](https://github.com/satra)) - [pre-commit.ci] pre-commit autoupdate [#269](https://github.com/neuronets/nobrainer/pull/269) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) #### Authors: 3 - [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) - H Gazula ([@hvgazula](https://github.com/hvgazula)) - Satrajit Ghosh ([@satra](https://github.com/satra)) --- # 1.1.1 (Sat Oct 07 2023) #### 🐛 Bug Fix - Small changes to support long, preemptable training runs [#267](https://github.com/neuronets/nobrainer/pull/267) ([@ohinds](https://github.com/ohinds) [@satra](https://github.com/satra)) - [pre-commit.ci] pre-commit autoupdate [#268](https://github.com/neuronets/nobrainer/pull/268) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) #### Authors: 3 - [@ohinds](https://github.com/ohinds) - [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) - Satrajit Ghosh ([@satra](https://github.com/satra)) --- # 1.1.0 (Tue Sep 19 2023) #### 🚀 Enhancement - Changes required to support the warmstart guide notebook [#266](https://github.com/neuronets/nobrainer/pull/266) ([@ohinds](https://github.com/ohinds) [@satra](https://github.com/satra) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) #### 🐛 Bug Fix - Fix some typos using codespell [#262](https://github.com/neuronets/nobrainer/pull/262) ([@yarikoptic](https://github.com/yarikoptic) [@satra](https://github.com/satra)) - [pre-commit.ci] pre-commit autoupdate [#263](https://github.com/neuronets/nobrainer/pull/263) 
([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - Remove unnecessary keepalive runner [#265](https://github.com/neuronets/nobrainer/pull/265) ([@ohinds](https://github.com/ohinds)) - Dynamically provision self-hosted runner [#264](https://github.com/neuronets/nobrainer/pull/264) ([@ohinds](https://github.com/ohinds)) #### Authors: 4 - [@ohinds](https://github.com/ohinds) - [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) - Satrajit Ghosh ([@satra](https://github.com/satra)) - Yaroslav Halchenko ([@yarikoptic](https://github.com/yarikoptic)) --- # 1.0.0 (Thu Aug 31 2023) #### 💥 Breaking Change - `nobrainer` dataset API rework [#261](https://github.com/neuronets/nobrainer/pull/261) ([@ohinds](https://github.com/ohinds) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) #### 🐛 Bug Fix - Fix Dockerfiles and add GitHub Actions job [#252](https://github.com/neuronets/nobrainer/pull/252) ([@kabilar](https://github.com/kabilar) [@satra](https://github.com/satra)) - Self-hosted runner weekly keepalive [#260](https://github.com/neuronets/nobrainer/pull/260) ([@ohinds](https://github.com/ohinds)) - [pre-commit.ci] pre-commit autoupdate [#253](https://github.com/neuronets/nobrainer/pull/253) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - Processing model checkpointing [#256](https://github.com/neuronets/nobrainer/pull/256) ([@ohinds](https://github.com/ohinds) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - Verify that the ec2 instance we started is where we are running. 
[#257](https://github.com/neuronets/nobrainer/pull/257) ([@ohinds](https://github.com/ohinds)) - Training from warm start with multiple GPUs [#251](https://github.com/neuronets/nobrainer/pull/251) ([@ohinds](https://github.com/ohinds) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) #### Authors: 4 - [@ohinds](https://github.com/ohinds) - [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) - Kabilar Gunalan ([@kabilar](https://github.com/kabilar)) - Satrajit Ghosh ([@satra](https://github.com/satra)) --- # 0.5.0 (Wed Jul 19 2023) #### 🚀 Enhancement - Remove guide [#243](https://github.com/neuronets/nobrainer/pull/243) ([@ohinds](https://github.com/ohinds) [@satra](https://github.com/satra)) #### 🐛 Bug Fix - [pre-commit.ci] pre-commit autoupdate [#247](https://github.com/neuronets/nobrainer/pull/247) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - Fix #246: Transforms now return labels if passed [#250](https://github.com/neuronets/nobrainer/pull/250) ([@ohinds](https://github.com/ohinds) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - Check out the master branch of the book repo on CI [#249](https://github.com/neuronets/nobrainer/pull/249) ([@ohinds](https://github.com/ohinds)) - Nobrainer book guide examples run on AWS EC2 as a form of regression testing [#248](https://github.com/neuronets/nobrainer/pull/248) ([@ohinds](https://github.com/ohinds)) - [pre-commit.ci] pre-commit autoupdate [#236](https://github.com/neuronets/nobrainer/pull/236) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - Multi-GPU support [#242](https://github.com/neuronets/nobrainer/pull/242) ([@ohinds](https://github.com/ohinds) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#234](https://github.com/neuronets/nobrainer/pull/234) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - 
[pre-commit.ci] pre-commit autoupdate [#232](https://github.com/neuronets/nobrainer/pull/232) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#231](https://github.com/neuronets/nobrainer/pull/231) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) #### ⚠️ Pushed to `master` - Update setup.cfg ([@satra](https://github.com/satra)) - Update ci.yml ([@satra](https://github.com/satra)) - update python and tensorflow versions ([@satra](https://github.com/satra)) - [CI] update python and auto versions ([@satra](https://github.com/satra)) - replace special branch ([@satra](https://github.com/satra)) #### Authors: 3 - [@ohinds](https://github.com/ohinds) - [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) - Satrajit Ghosh ([@satra](https://github.com/satra)) --- # 0.4.0 (Tue Oct 18 2022) #### 🚀 Enhancement - update actions [#230](https://github.com/neuronets/nobrainer/pull/230) ([@satra](https://github.com/satra)) - [pre-commit.ci] pre-commit autoupdate [#229](https://github.com/neuronets/nobrainer/pull/229) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) #### 🐛 Bug Fix - Enh/api [#228](https://github.com/neuronets/nobrainer/pull/228) ([@satra](https://github.com/satra) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#227](https://github.com/neuronets/nobrainer/pull/227) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - Create unet_lstm.py [#207](https://github.com/neuronets/nobrainer/pull/207) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra) [@Hoda1394](https://github.com/Hoda1394)) - [pre-commit.ci] pre-commit autoupdate [#226](https://github.com/neuronets/nobrainer/pull/226) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate 
[#225](https://github.com/neuronets/nobrainer/pull/225) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#224](https://github.com/neuronets/nobrainer/pull/224) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#223](https://github.com/neuronets/nobrainer/pull/223) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [WIP] Updating and simplifying the end-user API of Nobrainer [#215](https://github.com/neuronets/nobrainer/pull/215) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@Hoda1394](https://github.com/Hoda1394) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - [pre-commit.ci] pre-commit autoupdate [#220](https://github.com/neuronets/nobrainer/pull/220) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#218](https://github.com/neuronets/nobrainer/pull/218) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - Hard-Coded Scaling is removed for ProgressiveGANs, GanTrainer and ProgressiveAE Training and Model files [#209](https://github.com/neuronets/nobrainer/pull/209) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - generation api [#216](https://github.com/neuronets/nobrainer/pull/216) ([@Hoda1394](https://github.com/Hoda1394) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [WIP] segmentation api [#213](https://github.com/neuronets/nobrainer/pull/213) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@Hoda1394](https://github.com/Hoda1394) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - file name corrected for bayesian_meshnet.py [#211](https://github.com/neuronets/nobrainer/pull/211) ([@Hoda1394](https://github.com/Hoda1394) [@satra](https://github.com/satra)) - [pre-commit.ci] 
pre-commit autoupdate [#212](https://github.com/neuronets/nobrainer/pull/212) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - Adding vox2vox model [#205](https://github.com/neuronets/nobrainer/pull/205) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@Hoda1394](https://github.com/Hoda1394)) - added tensorflow-addons as install dependency [#206](https://github.com/neuronets/nobrainer/pull/206) ([@Hoda1394](https://github.com/Hoda1394)) - Composable Data Augmentations [#189](https://github.com/neuronets/nobrainer/pull/189) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) #### ⚠️ Pushed to `master` - update dockerfiles ([@satra](https://github.com/satra)) #### Authors: 4 - [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) - Aakanksha Rana ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana)) - Hoda Rajaei ([@Hoda1394](https://github.com/Hoda1394)) - Satrajit Ghosh ([@satra](https://github.com/satra)) --- # 0.3.0 (Tue Jan 11 2022) #### 🚀 Enhancement - Update README.md [#203](https://github.com/neuronets/nobrainer/pull/203) ([@satra](https://github.com/satra)) #### 🐛 Bug Fix - Progressive auto encoder [#196](https://github.com/neuronets/nobrainer/pull/196) (alice.bizeul@inf.ethz.ch [@satra](https://github.com/satra) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#197](https://github.com/neuronets/nobrainer/pull/197) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) #### ⚠️ Pushed to `master` - Update README.md ([@satra](https://github.com/satra)) #### Authors: 3 - [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) - Alice Elsa Marie Bizeul (alice.bizeul@inf.ethz.ch) - Satrajit Ghosh ([@satra](https://github.com/satra)) --- # 0.2.1 (Fri Dec 24 2021) #### 🐛 Bug Fix 
- fix: update docker files and citation pointer [#194](https://github.com/neuronets/nobrainer/pull/194) ([@satra](https://github.com/satra) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) #### Authors: 2 - [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) - Satrajit Ghosh ([@satra](https://github.com/satra)) --- # 0.2.0 (Fri Dec 24 2021) #### 🚀 Enhancement - Update README.md and move transforms [#187](https://github.com/neuronets/nobrainer/pull/187) ([@Hoda1394](https://github.com/Hoda1394) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) #### 🐛 Bug Fix - Documentation update [#184](https://github.com/neuronets/nobrainer/pull/184) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - Correction of G_dice and multi_class dice_calculation [#164](https://github.com/neuronets/nobrainer/pull/164) ([@kaczmarj](https://github.com/kaczmarj) [@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - DOC: Update link to CITATION file [#190](https://github.com/neuronets/nobrainer/pull/190) ([@arokem](https://github.com/arokem)) - Bayesian: 3D bayes_by_backprop_layer + Distributed Weight Consolidation [#185](https://github.com/neuronets/nobrainer/pull/185) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - [pre-commit.ci] pre-commit autoupdate [#188](https://github.com/neuronets/nobrainer/pull/188) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - Fix tf & tfp compatibility with python version [#186](https://github.com/neuronets/nobrainer/pull/186) ([@Hoda1394](https://github.com/Hoda1394) [@satra](https://github.com/satra)) - [pre-commit.ci] pre-commit autoupdate 
[#183](https://github.com/neuronets/nobrainer/pull/183) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - Added Clipping constraint and KLD loss with CD [#180](https://github.com/neuronets/nobrainer/pull/180) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - ENH: Allow multichannel input [#177](https://github.com/neuronets/nobrainer/pull/177) ([@richford](https://github.com/richford) [@satra](https://github.com/satra)) - Add citation [#181](https://github.com/neuronets/nobrainer/pull/181) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - [pre-commit.ci] pre-commit autoupdate [#182](https://github.com/neuronets/nobrainer/pull/182) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - adding brain siamese network models and test cases [#172](https://github.com/neuronets/nobrainer/pull/172) ([@dhritimandas](https://github.com/dhritimandas) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - Added dcgan architecture with tests [#66](https://github.com/neuronets/nobrainer/pull/66) ([@wazeerzulfikar](https://github.com/wazeerzulfikar) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - BF: Add block_length param to dataset.interleave [#175](https://github.com/neuronets/nobrainer/pull/175) ([@richford](https://github.com/richford) [@satra](https://github.com/satra)) - Weightnorm feature for Variational layer [#166](https://github.com/neuronets/nobrainer/pull/166) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - [pre-commit.ci] pre-commit autoupdate [#179](https://github.com/neuronets/nobrainer/pull/179) 
([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#178](https://github.com/neuronets/nobrainer/pull/178) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#173](https://github.com/neuronets/nobrainer/pull/173) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#170](https://github.com/neuronets/nobrainer/pull/170) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#165](https://github.com/neuronets/nobrainer/pull/165) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#163](https://github.com/neuronets/nobrainer/pull/163) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - fix: readme to remove qualified install [#162](https://github.com/neuronets/nobrainer/pull/162) ([@satra](https://github.com/satra) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - fix: remove unnecessary step in release action [#158](https://github.com/neuronets/nobrainer/pull/158) ([@satra](https://github.com/satra)) #### ⚠️ Pushed to `master` - Update .zenodo.json ([@satra](https://github.com/satra)) #### Authors: 9 - [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) - Aakanksha Rana ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana)) - Adam Richie-Halford ([@richford](https://github.com/richford)) - Ariel Rokem ([@arokem](https://github.com/arokem)) - Dhritiman Das ([@dhritimandas](https://github.com/dhritimandas)) - Hoda Rajaei ([@Hoda1394](https://github.com/Hoda1394)) - Jakub Kaczmarzyk ([@kaczmarj](https://github.com/kaczmarj)) - Satrajit Ghosh ([@satra](https://github.com/satra)) - Wazeer Zulfikar ([@wazeerzulfikar](https://github.com/wazeerzulfikar)) --- # 0.1.1 (Tue Jun 22 2021) #### 🐛 Bug Fix - fix: replace key retrieval and normalizers 
[#157](https://github.com/neuronets/nobrainer/pull/157) ([@satra](https://github.com/satra)) - fix: separate auto release and publish to pypi [#156](https://github.com/neuronets/nobrainer/pull/156) ([@satra](https://github.com/satra)) - [pre-commit.ci] pre-commit autoupdate [#154](https://github.com/neuronets/nobrainer/pull/154) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) #### ⚠️ Pushed to `master` - fix: add twine upload to release ([@satra](https://github.com/satra)) #### Authors: 2 - [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) - Satrajit Ghosh ([@satra](https://github.com/satra)) --- # 0.1.0 (Sat Jun 19 2021) #### 🚀 Enhancement - fix: change nobrainer models repo, readme, and guide notebooks [#150](https://github.com/neuronets/nobrainer/pull/150) ([@satra](https://github.com/satra) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - Enh/docker [#148](https://github.com/neuronets/nobrainer/pull/148) ([@satra](https://github.com/satra)) - fix: Update release.yml to handle branch protection [#147](https://github.com/neuronets/nobrainer/pull/147) ([@satra](https://github.com/satra)) #### 🐛 Bug Fix - add zenodo file [#153](https://github.com/neuronets/nobrainer/pull/153) ([@satra](https://github.com/satra) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - enh: change standardize in dataset creation to a callable [#152](https://github.com/neuronets/nobrainer/pull/152) ([@satra](https://github.com/satra)) - ENH: Bayesian neural network architectures with training requisites [#126](https://github.com/neuronets/nobrainer/pull/126) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@satra](https://github.com/satra) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - fix: standardize not imported for dataset creation [#151](https://github.com/neuronets/nobrainer/pull/151) ([@satra](https://github.com/satra)) - add estimator prediction function for kwyk model 
[#149](https://github.com/neuronets/nobrainer/pull/149) ([@Hoda1394](https://github.com/Hoda1394) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) [@satra](https://github.com/satra)) - Fix for the generate cli pytest and generation guide notebook [#146](https://github.com/neuronets/nobrainer/pull/146) ([@wazeerzulfikar](https://github.com/wazeerzulfikar)) - Created using Colaboratory [#144](https://github.com/neuronets/nobrainer/pull/144) ([@satra](https://github.com/satra) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - [pre-commit.ci] pre-commit autoupdate [#142](https://github.com/neuronets/nobrainer/pull/142) ([@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - FIX: keeping compatibility with official tensorflow docker images [#141](https://github.com/neuronets/nobrainer/pull/141) ([@satra](https://github.com/satra)) - Add bayesian prediction [#128](https://github.com/neuronets/nobrainer/pull/128) ([@Hoda1394](https://github.com/Hoda1394) [@satra](https://github.com/satra) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - ENH: Transform composition [#113](https://github.com/neuronets/nobrainer/pull/113) ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana) [@satra](https://github.com/satra) [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot])) - fixes https://github.com/neuronets/nobrainer/issues/124 [#125](https://github.com/neuronets/nobrainer/pull/125) ([@shashankbansal6](https://github.com/shashankbansal6) [@satra](https://github.com/satra)) - fix: update dockerfiles to tensorflow 2.5.0 [#140](https://github.com/neuronets/nobrainer/pull/140) ([@satra](https://github.com/satra)) - ENH: Add progressive gan to nobrainer [#138](https://github.com/neuronets/nobrainer/pull/138) ([@wazeerzulfikar](https://github.com/wazeerzulfikar) [@satra](https://github.com/satra) [@djarecka](https://github.com/djarecka)) - add: release mechanism [#139](https://github.com/neuronets/nobrainer/pull/139) 
([@satra](https://github.com/satra)) - Add progressiveGAN for 3D brain MR images [#114](https://github.com/neuronets/nobrainer/pull/114) ([@wazeerzulfikar](https://github.com/wazeerzulfikar)) - Enh/notebooks [#137](https://github.com/neuronets/nobrainer/pull/137) ([@satra](https://github.com/satra)) - add CI workflow badge [#133](https://github.com/neuronets/nobrainer/pull/133) ([@kaczmarj](https://github.com/kaczmarj)) - move CI to github actions [#131](https://github.com/neuronets/nobrainer/pull/131) ([@Hoda1394](https://github.com/Hoda1394)) - Enh/minor updates [#123](https://github.com/neuronets/nobrainer/pull/123) ([@kaczmarj](https://github.com/kaczmarj)) - export `LC_ALL` and `LANG` + use tensorflow 2.3.1 [#118](https://github.com/neuronets/nobrainer/pull/118) ([@kaczmarj](https://github.com/kaczmarj)) - use specific version of cloudpickle to fix import error [#108](https://github.com/neuronets/nobrainer/pull/108) ([@kaczmarj](https://github.com/kaczmarj)) - force reinstall of python dependencies in travis ci [#102](https://github.com/neuronets/nobrainer/pull/102) ([@kaczmarj](https://github.com/kaczmarj)) #### ⚠️ Pushed to `master` - fix: use token for auto with repo access ([@satra](https://github.com/satra)) - Created using Colaboratory ([@satra](https://github.com/satra)) - Update release.yml ([@satra](https://github.com/satra)) #### Authors: 8 - [@Hoda1394](https://github.com/Hoda1394) - [@pre-commit-ci[bot]](https://github.com/pre-commit-ci[bot]) - Aakanksha Rana ([@Aakanksha-Rana](https://github.com/Aakanksha-Rana)) - Dorota Jarecka ([@djarecka](https://github.com/djarecka)) - Jakub Kaczmarzyk ([@kaczmarj](https://github.com/kaczmarj)) - Satrajit Ghosh ([@satra](https://github.com/satra)) - Shashank Bansal ([@shashankbansal6](https://github.com/shashankbansal6)) - Wazeer Zulfikar ([@wazeerzulfikar](https://github.com/wazeerzulfikar)) ================================================ FILE: CITATION ================================================ Please 
follow this DOI (https://doi.org/10.5281/zenodo.4995077) to find the latest citation on Zenodo. The different citation formats are available in the Share and Export sections of the page. On a desktop browser these are on the bottom right of the page. ================================================ FILE: CLAUDE.md ================================================ # Nobrainer Development Guidelines ## Project Overview Nobrainer is a PyTorch-based deep learning library for 3D brain MRI segmentation. It provides scikit-learn-style estimators (`Segmentation`, `Generation`), Bayesian models (VWN/FFG, Pyro-based), and a comprehensive data pipeline (MONAI transforms, Zarr3 stores, SynthSeg generation). ## Technology Stack - **Python**: 3.12+; CI matrix 3.12/3.13/3.14 - **Package management**: `uv` throughout (never pip/conda/poetry) - **ML framework**: PyTorch >= 2.0 - **Medical imaging**: MONAI >= 1.3 (transforms, losses, metrics, model wrappers) - **Bayesian**: Pyro-ppl >= 1.9 (optional `[bayesian]` extra) - **Data**: Zarr >= 3.0 (optional `[zarr]` extra), NIfTI via nibabel - **Testing**: pytest; pre-commit (black, flake8, isort, codespell) - **CI**: GitHub Actions; EC2 GPU runner for GPU tests ## Commands ```bash # Install uv pip install -e ".[all]" # Test (CPU) uv run pytest nobrainer/tests/unit/ -m "not gpu" --tb=short # SR-tests (somewhat realistic, need sample brain data) uv run pytest nobrainer/sr-tests/ -m "not gpu" # Lint uv run pre-commit run --all-files ``` ## Code Conventions - All models: `(B, C, D, H, W)` input → `(B, n_classes, D, H, W)` output - Factory functions: `model_name(n_classes=1, in_channels=1, **kwargs) -> nn.Module` - Bayesian models: `supports_mc = True` class attribute; `forward(x, **kwargs)` accepts `mc=True/False` - Prediction: use `model_supports_mc(model)` to check, never `try/except TypeError` - Labels: always squeeze channel dim + cast to `long` before `CrossEntropyLoss` - Device selection: `nobrainer.gpu.get_device()` (CUDA > MPS > CPU) 
- Data augmentation: `TrainableCompose` wraps MONAI Compose; `Augmentation()` wrapper auto-skips during predict ## Key Modules | Module | Purpose | |--------|---------| | `models/` | MeshNet, SegFormer3D, UNet, SwinUNETR, SegResNet, Bayesian variants | | `processing/` | Segmentation/Generation estimators, Dataset builder | | `augmentation/` | SynthSeg generator, TrainableCompose, profiles | | `datasets/` | OpenNeuro fetching, Zarr3 store management | | `training.py` | `fit()` with DDP, AMP, validation, callbacks | | `prediction.py` | Block-based predict, strided reassembly, MC uncertainty | | `losses.py` | Dice, FocalLoss, DiceCE, ELBO, class weights | | `gpu.py` | Device detection, auto batch size, multi-GPU scaling | | `slurm.py` | SLURM preemption handler, checkpoint/resume | | `experiment.py` | Local JSONL/CSV + optional W&B tracking | ## Development Workflow (Speckit Constitution) When working on new features or significant changes, follow these principles: ### I. Specification-First Every feature MUST begin with a written specification before implementation: - Prioritized user stories with independently testable acceptance scenarios - Functional requirements written as verifiable constraints (MUST/SHOULD) - Measurable success criteria that are technology-agnostic ### II. Incremental Planning Plans are built in ordered phases — no phase may be skipped: - **Phase 0 — Research**: Resolve all unknowns before design - **Phase 1 — Design**: Data model, interface contracts, quickstart documented - **Phase 2 — Tasks**: Actionable task list organized by user story priority Implementation MUST NOT begin until tasks exist. ### III. Independent User-Story Delivery - Each P1 story MUST produce a viable MVP with standalone value - Stories MUST NOT have hard runtime dependencies on lower-priority stories - Tasks MUST be labeled with their owning story (`[US1]`, `[US2]`, etc.) ### IV. 
Constitution Compliance Gate Every plan MUST include a Constitution Check evaluated before research and after design. Violations MUST be justified with a simpler alternative explicitly rejected. ### V. Simplicity & YAGNI - Prefer the simplest architecture that satisfies current user stories - Do not introduce abstractions for hypothetical future requirements - Complexity MUST be justified against a concrete, present need ### VI. Git Commit Discipline - Feature work on dedicated branches (`###-feature-name`) - Planning artifacts committed after each speckit command - Each completed task results in at least one commit - Prefer new commits over amending ### VII. Technology Stack Standards - Python: `uv` for all environment and package management - Containers: Docker only - No substitutions without justified amendment ## Quality Gates | Gate | Condition | |------|-----------| | G1 | spec.md has ≥1 user story with acceptance scenarios | | G2 | All NEEDS CLARIFICATION resolved before design | | G3 | Constitution Check passes (or violations justified) | | G4 | tasks.md exists and all tasks reference a user story | | G5 | P1 story independently verified before P2 work | | G6 | All planning artifacts committed to feature branch | ## Speckit Commands (if available) ``` /speckit.specify → spec.md /speckit.clarify → spec.md (revised) /speckit.plan → plan.md, research.md, data-model.md, quickstart.md /speckit.tasks → tasks.md /speckit.implement → code /speckit.analyze → consistency report ``` If speckit is not installed, follow the principles above manually. ================================================ FILE: LICENSE ================================================ Copyright 2021 The Nobrainer Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: MANIFEST.in ================================================ # This line includes versioneer.py in sdists, which is necessary for wheels # built from sdists to have the version set in their metadata. include versioneer.py include CHANGELOG.md tox.ini graft nobrainer global-exclude *.py[cod] include nobrainer/_version.py ================================================ FILE: README.md ================================================ # Nobrainer ![Build status](https://github.com/neuronets/nobrainer/actions/workflows/ci.yml/badge.svg) _Nobrainer_ is a deep learning framework for 3D brain image processing built on **PyTorch** and **MONAI**. It provides segmentation models (deterministic and Bayesian), generative models, a MONAI-native data pipeline, block-based prediction with uncertainty quantification, and CLI tools for inference and automated hyperparameter search. Pre-trained models for brain extraction, segmentation, and generation are available in the [trained-models](https://github.com/neuronets/trained-models) repository. The _Nobrainer_ project is supported by NIH RF1MH121885 and is distributed under the Apache 2.0 license. 
## Models ### Segmentation | Model | Backend | Application | |-------|---------|-------------| | [UNet](nobrainer/models/segmentation.py) | MONAI | segmentation | | [VNet](nobrainer/models/segmentation.py) | MONAI | segmentation | | [Attention U-Net](nobrainer/models/segmentation.py) | MONAI | segmentation | | [UNETR](nobrainer/models/segmentation.py) | MONAI | segmentation | | [MeshNet](nobrainer/models/meshnet.py) | PyTorch | segmentation | | [HighResNet](nobrainer/models/highresnet.py) | PyTorch | segmentation | ### Bayesian (uncertainty quantification) | Model | Backend | Application | |-------|---------|-------------| | [Bayesian VNet](nobrainer/models/bayesian/bayesian_vnet.py) | Pyro | segmentation + uncertainty | | [Bayesian MeshNet](nobrainer/models/bayesian/bayesian_meshnet.py) | Pyro | segmentation + uncertainty | ### Generative | Model | Backend | Application | |-------|---------|-------------| | [Progressive GAN](nobrainer/models/generative/progressivegan.py) | PyTorch Lightning | brain generation | | [DCGAN](nobrainer/models/generative/dcgan.py) | PyTorch Lightning | brain generation | ### Other | Model | Application | |-------|-------------| | [Autoencoder](nobrainer/models/autoencoder.py) | representation learning | | [SimSiam](nobrainer/models/simsiam.py) | self-supervised learning | ### Custom layers - `BernoulliDropout`, `ConcreteDropout`, `GaussianDropout` — stochastic regularization - `BayesianConv3d`, `BayesianLinear` — Pyro-based weight uncertainty layers - `MaxPool4D` — 4D max pooling via reshape ### Losses and metrics **Losses**: Dice, Generalized Dice, Jaccard, Tversky, ELBO (Bayesian), Wasserstein, Gradient Penalty **Metrics**: Dice, Jaccard, Hausdorff distance (all via MONAI) ## Installation ### pip / uv ```bash uv venv --python 3.14 source .venv/bin/activate uv pip install nobrainer ``` For Bayesian and generative model support: ```bash uv pip install "nobrainer[bayesian,generative]" monai pyro-ppl ``` ### Docker GPU image (requires 
NVIDIA driver on host): ```bash docker pull neuronets/nobrainer:latest-gpu-pt docker run --gpus all --rm neuronets/nobrainer:latest-gpu-pt predict --help ``` CPU-only image: ```bash docker pull neuronets/nobrainer:latest-cpu-pt docker run --rm neuronets/nobrainer:latest-cpu-pt predict --help ``` ## Quick start ### Tutorials See the [Nobrainer Book](https://neuronets.dev/nobrainer-book/) for 11 progressive tutorials — from installation to contributing. ### sr-tests (somewhat realistic tests) `nobrainer/sr-tests/` contains pytest integration tests that exercise the real API with real brain data. They run in CI on every push: ```bash pytest nobrainer/sr-tests/ -v -m "not gpu" --tb=short ``` ### Simple API (3 lines) ```python from nobrainer.processing import Segmentation, Dataset ds = Dataset.from_files(filepaths, block_shape=(128, 128, 128), n_classes=2).batch(2) result = Segmentation("unet").fit(ds, epochs=5).predict("brain.nii.gz") ``` Models are saved with [Croissant-ML](https://mlcommons.org/croissant/) metadata for reproducibility: ```python seg = Segmentation("unet").fit(ds, epochs=5) seg.save("my_model") # Creates model.pth + croissant.json seg = Segmentation.load("my_model") ``` ### Brain segmentation (CLI) ```bash nobrainer predict \ --model unet_brainmask.pth \ --model-type unet \ --n-classes 2 \ input_T1w.nii.gz output_mask.nii.gz ``` ### Brain segmentation (Python) ```python import torch import nobrainer.models from nobrainer.prediction import predict model = nobrainer.models.unet(n_classes=2) model.load_state_dict(torch.load("unet_brainmask.pth")) model.eval() result = predict( inputs="input_T1w.nii.gz", model=model, block_shape=(128, 128, 128), device="cuda", ) result.to_filename("output_mask.nii.gz") ``` ### Bayesian inference with uncertainty maps ```python from nobrainer.prediction import predict_with_uncertainty model = nobrainer.models.bayesian_vnet(n_classes=2) model.load_state_dict(torch.load("bayesian_vnet.pth")) label, variance, entropy = predict_with_uncertainty( inputs="input_T1w.nii.gz",
model=model, n_samples=10, block_shape=(128, 128, 128), device="cuda", ) label.to_filename("label.nii.gz") variance.to_filename("variance.nii.gz") entropy.to_filename("entropy.nii.gz") ``` ### Brain generation ```bash nobrainer generate \ --model progressivegan.ckpt \ --model-type progressivegan \ output_synthetic.nii.gz ``` ### Zarr v3 data pipeline ```python from nobrainer.io import nifti_to_zarr, zarr_to_nifti # Convert NIfTI to sharded Zarr v3 with multi-resolution pyramid nifti_to_zarr("brain_T1w.nii.gz", "brain.zarr", chunk_shape=(64, 64, 64), levels=3) # Load Zarr stores directly in the training pipeline from nobrainer.dataset import get_dataset loader = get_dataset( data=[{"image": "brain.zarr", "label": "label.zarr"}], batch_size=2, ) # Round-trip back to NIfTI zarr_to_nifti("brain.zarr", "brain_roundtrip.nii.gz") ``` ### Training a model ```python import torch import nobrainer.models from nobrainer.dataset import get_dataset from nobrainer.losses import dice data_files = [ {"image": f"sub-{i:03d}_T1w.nii.gz", "label": f"sub-{i:03d}_label.nii.gz"} for i in range(1, 101) ] loader = get_dataset(data=data_files, batch_size=2, augment=True, cache=True) model = nobrainer.models.unet(n_classes=2).cuda() optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) criterion = dice() for epoch in range(50): model.train() for batch in loader: images, labels = batch["image"].cuda(), batch["label"].cuda() optimizer.zero_grad() loss = criterion(model(images), labels) loss.backward() optimizer.step() torch.save(model.state_dict(), "unet_trained.pth") ``` ## Automated research (autoresearch) Nobrainer includes an automated hyperparameter search loop that uses an LLM to propose training modifications overnight: ```bash nobrainer research run \ --working-dir ./research/bayesian_vnet \ --model-family bayesian_vnet \ --max-experiments 15 \ --budget-hours 8 ``` Improved models are versioned via DataLad: ```bash nobrainer research commit \ --run-dir ./research/bayesian_vnet \ --trained-models-path
~/trained-models \ --model-family bayesian_vnet ``` ## GPU test dispatch (nobrainer-runner) [nobrainer-runner](https://github.com/neuronets/nobrainer-runner) submits GPU test suites to Slurm clusters or cloud instances (AWS Batch, GCP Batch): ```bash nobrainer-runner submit --profile mycluster --gpus 1 "pytest tests/ -m gpu" nobrainer-runner status $JOB_ID nobrainer-runner results --format json $JOB_ID ``` ## Package layout - `nobrainer.models` — segmentation, Bayesian, and generative `torch.nn.Module` models - `nobrainer.losses` — Dice, Jaccard, Tversky, ELBO, Wasserstein (MONAI-backed) - `nobrainer.metrics` — Dice, Jaccard, Hausdorff (MONAI-backed) - `nobrainer.dataset` — MONAI `CacheDataset` + `DataLoader` pipeline - `nobrainer.prediction` — block-based `predict()` and `predict_with_uncertainty()` - `nobrainer.io` — `convert_tfrecords()`, `convert_weights()` (TF → PyTorch migration) - `nobrainer.layers` — dropout layers, Bayesian layers, MaxPool4D - `nobrainer.research` — autoresearch loop and DataLad model versioning - `nobrainer.cli` — Click CLI (`predict`, `generate`, `research`, `commit`, `info`) ## Development and releases Nobrainer uses a two-branch release workflow: | Branch | Purpose | PyPI version | |--------|---------|--------------| | `master` | Stable releases | `uv pip install nobrainer` | | `alpha` | Pre-releases for testing | `uv pip install --pre nobrainer` | **Alpha workflow**: Feature branches merge to `alpha`. Each merge triggers book tutorial validation (using a matching branch on [nobrainer-book](https://github.com/neuronets/nobrainer-book) if available, otherwise the book's `alpha` branch) followed by an automatic pre-release tag (e.g., `0.5.0-alpha.0`). **Stable workflow**: When `alpha` is merged to `master` with the `release` label, a stable version is tagged and published to PyPI. **GPU CI**: PRs to `master` can request GPU testing on EC2 by adding the `gpu-test-approved` label. 
Instance type and spot pricing are configurable via `gpu-instance:` and `gpu-spot:true` labels. ## Citation If you use this package, please [cite](https://github.com/neuronets/nobrainer/blob/master/CITATION) it. ## Questions or issues Please [submit a GitHub issue](https://github.com/neuronets/helpdesk/issues/new/choose). ================================================ FILE: conftest.py ================================================ """Root conftest.py — auto-skip GPU tests when CUDA is unavailable.""" from __future__ import annotations import pytest import torch def pytest_collection_modifyitems(config, items): """Skip tests marked with @pytest.mark.gpu when CUDA is not available.""" if torch.cuda.is_available(): return skip_gpu = pytest.mark.skip(reason="CUDA not available — skipping GPU test") for item in items: if item.get_closest_marker("gpu"): item.add_marker(skip_gpu) ================================================ FILE: docker/README.md ================================================ # Nobrainer in a container The Dockerfiles in this directory can be used to create Docker images to use _Nobrainer_ on CPU or GPU. ## Build images ```bash cd /code/nobrainer # Top-level nobrainer directory docker build -t neuronets/nobrainer:master-cpu -f docker/cpu.Dockerfile . docker build -t neuronets/nobrainer:master-gpu -f docker/gpu.Dockerfile . ``` # Convert Docker images to Singularity containers Using Singularity version 3.x, Docker images can be converted to Singularity containers using the `singularity` command-line tool. 
## Pulling from DockerHub In most cases (e.g., working on an HPC cluster), the _Nobrainer_ Singularity container can be created with: ```bash singularity pull docker://neuronets/nobrainer:master-gpu ``` ## Building from local Docker cache If you built a _Nobrainer_ Docker image locally and would like to convert it to a Singularity container, you can do so with: ```bash sudo singularity pull docker-daemon://neuronets/nobrainer:master-gpu ``` Please note the use of `sudo` here. This is necessary for interacting with the Docker daemon. ================================================ FILE: docker/cpu.Dockerfile ================================================ FROM python:3.14-slim ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y --no-install-recommends \ git \ && rm -rf /var/lib/apt/lists/* COPY [".", "/opt/nobrainer"] RUN pip install --no-cache-dir uv \ && uv pip install --system \ "torch" \ "/opt/nobrainer[bayesian,generative]" \ monai \ pyro-ppl \ --index-url https://download.pytorch.org/whl/cpu \ --extra-index-url https://pypi.org/simple \ && rm -rf /root/.cache/uv ENV LC_ALL=C.UTF-8 \ LANG=C.UTF-8 WORKDIR "/work" LABEL maintainer="Satrajit Ghosh " LABEL org.opencontainers.image.title="nobrainer-cpu-pytorch" LABEL org.opencontainers.image.description="nobrainer with PyTorch CPU-only support" ENTRYPOINT ["nobrainer"] ================================================ FILE: docker/gpu.Dockerfile ================================================ FROM python:3.14-slim ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y --no-install-recommends \ git \ && rm -rf /var/lib/apt/lists/* COPY [".", "/opt/nobrainer"] RUN pip install --no-cache-dir uv \ && uv pip install --system \ torch \ "/opt/nobrainer[bayesian,generative,versioning]" \ monai \ pyro-ppl \ && rm -rf /root/.cache/uv ENV LC_ALL=C.UTF-8 \ LANG=C.UTF-8 WORKDIR "/work" LABEL maintainer="Satrajit Ghosh " LABEL org.opencontainers.image.title="nobrainer-gpu-pytorch"
LABEL org.opencontainers.image.description="nobrainer with PyTorch GPU support (CUDA via host driver)" ENTRYPOINT ["nobrainer"] ================================================ FILE: nobrainer/__init__.py ================================================ try: from ._version import __version__ # noqa: F401 except (ImportError, ModuleNotFoundError): try: from . import _version # noqa: F401 __version__ = _version.get_versions()["version"] except (ImportError, AttributeError): __version__ = "0.0.0.dev0" # Lazy imports: submodules are available via nobrainer.io, nobrainer.models, etc. # but are not eagerly loaded to avoid requiring optional dependencies (monai, # pyro-ppl, pytorch-lightning) at import time. ================================================ FILE: nobrainer/_version.py ================================================ # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. # This file is released into the public domain. Generated by # versioneer-0.21 (https://github.com/python-versioneer/python-versioneer) """Git implementation of _version.py.""" import errno import os import re import subprocess import sys from typing import Callable, Dict def get_keywords(): """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must # each be defined on a line of their own. _version.py will just call # get_keywords(). 
git_refnames = "$Format:%d$" git_full = "$Format:%H$" git_date = "$Format:%ci$" keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} return keywords class VersioneerConfig: """Container for Versioneer configuration parameters.""" def get_config(): """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py cfg = VersioneerConfig() cfg.VCS = "git" cfg.style = "pep440" cfg.tag_prefix = "" cfg.parentdir_prefix = "" cfg.versionfile_source = "nobrainer/_version.py" cfg.verbose = False return cfg class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" LONG_VERSION_PY: Dict[str, str] = {} HANDLERS: Dict[str, Dict[str, Callable]] = {} def register_vcs_handler(vcs, method): # decorator """Create decorator to mark a method as the handler of a VCS.""" def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f return decorate def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) process = None for command in commands: try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git process = subprocess.Popen( [command] + args, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=(subprocess.PIPE if hide_stderr else None), ) break except OSError: e = sys.exc_info()[1] if e.errno == errno.ENOENT: continue if verbose: print("unable to run %s" % dispcmd) print(e) return None, None else: if verbose: print("unable to find command, tried %s" % (commands,)) return None, None stdout = process.communicate()[0].strip().decode() if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) return None, process.returncode return stdout, process.returncode def 
versions_from_parentdir(parentdir_prefix, root, verbose): """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both the project name and a version string. We will also support searching up two directory levels for an appropriately named parent directory """ rootdirs = [] for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return { "version": dirname[len(parentdir_prefix) :], "full-revisionid": None, "dirty": False, "error": None, "date": None, } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: print( "Tried directories %s but none started with prefix %s" % (str(rootdirs), parentdir_prefix) ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @register_vcs_handler("git", "get_keywords") def git_get_keywords(versionfile_abs): """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. keywords = {} try: with open(versionfile_abs, "r") as fobj: for line in fobj: if line.strip().startswith("git_refnames ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["refnames"] = mo.group(1) if line.strip().startswith("git_full ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["full"] = mo.group(1) if line.strip().startswith("git_date ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["date"] = mo.group(1) except OSError: pass return keywords @register_vcs_handler("git", "keywords") def git_versions_from_keywords(keywords, tag_prefix, verbose): """Get version information from git keywords.""" if "refnames" not in keywords: raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: # Use only the last line. 
Previous lines may contain GPG signature # information. date = date.splitlines()[-1] # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because # it's been around since git-1.5.3, and it's too difficult to # discover which version we're using, or to work around using an # older one. date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) refnames = keywords["refnames"].strip() if refnames.startswith("$Format"): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d # expansion behaves like git log --decorate=short and strips out the # refs/heads/ and refs/tags/ prefixes that would let us distinguish # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: print("likely tags: %s" % ",".join(sorted(tags))) for ref in sorted(tags): # sorting will prefer e.g. 
"2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) return { "version": r, "full-revisionid": keywords["full"].strip(), "dirty": False, "error": None, "date": date, } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") return { "version": "0+unknown", "full-revisionid": keywords["full"].strip(), "dirty": False, "error": "no suitable tags", "date": None, } @register_vcs_handler("git", "pieces_from_vcs") def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* expanded, and _version.py hasn't already been rewritten with a short version string, meaning we're inside a checked out source tree. 
""" GITS = ["git"] TAG_PREFIX_REGEX = "*" if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] TAG_PREFIX_REGEX = r"\*" _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) raise NotThisMethod("'git rev-parse --git-dir' returned error") # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) describe_out, rc = runner( GITS, [ "describe", "--tags", "--dirty", "--always", "--long", "--match", "%s%s" % (tag_prefix, TAG_PREFIX_REGEX), ], cwd=root, ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() pieces = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") branch_name = branch_name.strip() if branch_name == "HEAD": # If we aren't exactly on a branch, pick a branch which represents # the current commit. If all else fails, we are on a branchless # commit. branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) # --contains was added in git-1.5.4 if rc != 0 or branches is None: raise NotThisMethod("'git branch --contains' returned error") branches = branches.split("\n") # Remove the first line if we're running detached if "(" in branches[0]: branches.pop(0) # Strip off the leading "* " from the list of branches. 
branches = [branch[2:] for branch in branches] if "master" in branches: branch_name = "master" elif not branches: branch_name = None else: # Pick the first branch that is returned. Good or bad. branch_name = branches[0] pieces["branch"] = branch_name # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out # look for -dirty suffix dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag full_tag = mo.group(1) if not full_tag.startswith(tag_prefix): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( full_tag, tag_prefix, ) return pieces pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) # commit: short hex revision ID pieces["short"] = mo.group(3) else: # HEX: no tags pieces["closest-tag"] = None count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() # Use only the last line. Previous lines may contain GPG signature # information. date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces def plus_or_dot(pieces): """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." 
return "+"


def render_pep440(pieces):
    """Build up version string, with post-release "local version identifier".

    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty

    Exceptions:
    1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            rendered += plus_or_dot(pieces)
            rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
            if pieces["dirty"]:
                rendered += ".dirty"
    else:
        # exception #1
        rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            rendered += ".dirty"
    return rendered


def render_pep440_branch(pieces):
    """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .

    The ".dev0" means not master branch. Note that .dev0 sorts backwards
    (a feature branch will appear "older" than the master branch).

    Exceptions:
    1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            if pieces["branch"] != "master":
                rendered += ".dev0"
            rendered += plus_or_dot(pieces)
            rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
            if pieces["dirty"]:
                rendered += ".dirty"
    else:
        # exception #1
        rendered = "0"
        if pieces["branch"] != "master":
            rendered += ".dev0"
        rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            rendered += ".dirty"
    return rendered


def pep440_split_post(ver):
    """Split pep440 version string at the post-release segment.

    Returns the release segments before the post-release and the
    post-release version number (``None`` if no post-release segment is
    present; a bare ``.post`` with no number yields 0).
    """
    vc = str.split(ver, ".post")
    return vc[0], int(vc[1] or 0) if len(vc) == 2 else None


def render_pep440_pre(pieces):
    """TAG[.postN.devDISTANCE] -- No -dirty.

    Exceptions:
    1: no tags. 0.post0.devDISTANCE
    """
    if pieces["closest-tag"]:
        if pieces["distance"]:
            # update the post release segment
            tag_version, post_version = pep440_split_post(pieces["closest-tag"])
            rendered = tag_version
            if post_version is not None:
                rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
            else:
                rendered += ".post0.dev%d" % (pieces["distance"])
        else:
            # no commits, use the tag as the version
            rendered = pieces["closest-tag"]
    else:
        # exception #1
        rendered = "0.post0.dev%d" % pieces["distance"]
    return rendered


def render_pep440_post(pieces):
    """TAG[.postDISTANCE[.dev0]+gHEX] .

    The ".dev0" means dirty. Note that .dev0 sorts backwards
    (a dirty tree will appear "older" than the corresponding clean one),
    but you shouldn't be releasing software with -dirty anyways.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            rendered += ".post%d" % pieces["distance"]
            if pieces["dirty"]:
                rendered += ".dev0"
            rendered += plus_or_dot(pieces)
            rendered += "g%s" % pieces["short"]
    else:
        # exception #1
        rendered = "0.post%d" % pieces["distance"]
        if pieces["dirty"]:
            rendered += ".dev0"
        rendered += "+g%s" % pieces["short"]
    return rendered


def render_pep440_post_branch(pieces):
    """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .

    The ".dev0" means not master branch.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            rendered += ".post%d" % pieces["distance"]
            if pieces["branch"] != "master":
                rendered += ".dev0"
            rendered += plus_or_dot(pieces)
            rendered += "g%s" % pieces["short"]
            if pieces["dirty"]:
                rendered += ".dirty"
    else:
        # exception #1
        rendered = "0.post%d" % pieces["distance"]
        if pieces["branch"] != "master":
            rendered += ".dev0"
        rendered += "+g%s" % pieces["short"]
        if pieces["dirty"]:
            rendered += ".dirty"
    return rendered


def render_pep440_old(pieces):
    """TAG[.postDISTANCE[.dev0]] .

    The ".dev0" means dirty.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            rendered += ".post%d" % pieces["distance"]
            if pieces["dirty"]:
                rendered += ".dev0"
    else:
        # exception #1
        rendered = "0.post%d" % pieces["distance"]
        if pieces["dirty"]:
            rendered += ".dev0"
    return rendered


def render_git_describe(pieces):
    """TAG[-DISTANCE-gHEX][-dirty].

    Like 'git describe --tags --dirty --always'.

    Exceptions:
    1: no tags. HEX[-dirty] (note: no 'g' prefix)
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        if pieces["distance"]:
            rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
    else:
        # exception #1
        rendered = pieces["short"]
    if pieces["dirty"]:
        rendered += "-dirty"
    return rendered


def render_git_describe_long(pieces):
    """TAG-DISTANCE-gHEX[-dirty].

    Like 'git describe --tags --dirty --always -long'.
    The distance/hash is unconditional.

    Exceptions:
    1: no tags. HEX[-dirty] (note: no 'g' prefix)
    """
    if pieces["closest-tag"]:
        rendered = pieces["closest-tag"]
        rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
    else:
        # exception #1
        rendered = pieces["short"]
    if pieces["dirty"]:
        rendered += "-dirty"
    return rendered


def render(pieces, style):
    """Render the given version pieces into the requested style."""
    # A parse error recorded upstream short-circuits to an "unknown" version.
    if pieces["error"]:
        return {
            "version": "unknown",
            "full-revisionid": pieces.get("long"),
            "dirty": None,
            "error": pieces["error"],
            "date": None,
        }

    if not style or style == "default":
        style = "pep440"  # the default

    if style == "pep440":
        rendered = render_pep440(pieces)
    elif style == "pep440-branch":
        rendered = render_pep440_branch(pieces)
    elif style == "pep440-pre":
        rendered = render_pep440_pre(pieces)
    elif style == "pep440-post":
        rendered = render_pep440_post(pieces)
    elif style == "pep440-post-branch":
        rendered = render_pep440_post_branch(pieces)
    elif style == "pep440-old":
        rendered = render_pep440_old(pieces)
    elif style == "git-describe":
        rendered = render_git_describe(pieces)
    elif style == "git-describe-long":
        rendered = render_git_describe_long(pieces)
    else:
        raise ValueError("unknown style '%s'" % style)

    return {
        "version": rendered,
        "full-revisionid": pieces["long"],
        "dirty": pieces["dirty"],
        "error": None,
        "date": pieces.get("date"),
    }


def get_versions():
    """Get version information or return default if unable to do so."""
    # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
    # __file__, we can work backwards from there to the root. Some
    # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
    # case we can only use expanded keywords.

    cfg = get_config()
    verbose = cfg.verbose

    # Strategies are tried in order: expanded git-archive keywords, a live
    # git checkout, then the parent-directory name.
    try:
        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
    except NotThisMethod:
        pass

    try:
        root = os.path.realpath(__file__)
        # versionfile_source is the relative path from the top of the source
        # tree (where the .git directory might live) to this file. Invert
        # this to find the root from __file__.
        for _ in cfg.versionfile_source.split("/"):
            root = os.path.dirname(root)
    except NameError:
        return {
            "version": "0+unknown",
            "full-revisionid": None,
            "dirty": None,
            "error": "unable to find root of source tree",
            "date": None,
        }

    try:
        pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
        return render(pieces, cfg.style)
    except NotThisMethod:
        pass

    try:
        if cfg.parentdir_prefix:
            return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
    except NotThisMethod:
        pass

    return {
        "version": "0+unknown",
        "full-revisionid": None,
        "dirty": None,
        "error": "unable to compute version",
        "date": None,
    }

================================================
FILE: nobrainer/augmentation/__init__.py
================================================
"""Data augmentation: transform tagging, profiles, and SynthSeg generation."""

from .profiles import get_augmentation_profile
from .synthseg import SynthSegGenerator
from .transforms import Augmentation, TrainableCompose

__all__ = [
    "Augmentation",
"SynthSegGenerator", "TrainableCompose", "get_augmentation_profile", ] ================================================ FILE: nobrainer/augmentation/profiles.py ================================================ """Predefined augmentation profiles for brain imaging. Each profile returns a list of MONAI dictionary transforms wrapped with :class:`~nobrainer.augmentation.transforms.Augmentation` so they are automatically skipped during inference. Profiles: ``"none"``, ``"light"``, ``"standard"``, ``"heavy"``. Usage:: from nobrainer.augmentation.profiles import get_augmentation_profile transforms = get_augmentation_profile("standard", keys=["image", "label"]) """ from __future__ import annotations from .transforms import Augmentation def get_augmentation_profile( name: str, keys: list[str] | None = None, ) -> list: """Return a list of augmentation transforms for the given profile. All returned transforms are wrapped with :class:`Augmentation` so :class:`TrainableCompose` will skip them during inference. Parameters ---------- name : str Profile name: ``"none"``, ``"light"``, ``"standard"``, ``"heavy"``. keys : list of str or None MONAI dictionary keys (default ``["image", "label"]``). Returns ------- list List of ``Augmentation``-wrapped MONAI transforms. 
""" from monai.transforms import RandAffined, RandFlipd, RandGaussianNoised if keys is None: keys = ["image", "label"] img_keys = [k for k in keys if k == "image"] has_label = "label" in keys modes = ["bilinear", "nearest"] if has_label else ["bilinear"] if name == "none": return [] if name == "light": return [ Augmentation(RandFlipd(keys=keys, prob=0.5, spatial_axis=0)), Augmentation(RandFlipd(keys=keys, prob=0.5, spatial_axis=1)), Augmentation(RandFlipd(keys=keys, prob=0.5, spatial_axis=2)), ] if name == "standard": return [ Augmentation( RandAffined( keys=keys, prob=0.5, rotate_range=(0.15, 0.15, 0.15), scale_range=(0.1, 0.1, 0.1), mode=modes, padding_mode="border", ) ), Augmentation(RandFlipd(keys=keys, prob=0.5, spatial_axis=0)), Augmentation(RandFlipd(keys=keys, prob=0.5, spatial_axis=1)), Augmentation(RandFlipd(keys=keys, prob=0.5, spatial_axis=2)), Augmentation( RandGaussianNoised(keys=img_keys, prob=0.2, mean=0.0, std=0.1) ), ] if name == "heavy": return [ Augmentation( RandAffined( keys=keys, prob=0.8, rotate_range=(0.3, 0.3, 0.3), scale_range=(0.2, 0.2, 0.2), mode=modes, padding_mode="border", ) ), Augmentation(RandFlipd(keys=keys, prob=0.5, spatial_axis=0)), Augmentation(RandFlipd(keys=keys, prob=0.5, spatial_axis=1)), Augmentation(RandFlipd(keys=keys, prob=0.5, spatial_axis=2)), Augmentation( RandGaussianNoised(keys=img_keys, prob=0.5, mean=0.0, std=0.15) ), ] available = "none, light, standard, heavy" raise ValueError(f"Unknown augmentation profile '{name}'. Available: {available}") ================================================ FILE: nobrainer/augmentation/synthseg.py ================================================ """SynthSeg-style synthetic brain data generator. Enhanced implementation following Billot et al. 
(2023) with: - GMM tissue class grouping (labels grouped by tissue type) - Spatial augmentation (elastic deformation, rotation, scaling, flipping) - Resolution randomization (downsample + upsample) - Configurable intensity priors Reference: Billot et al., "SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining", Medical Image Analysis, 2023. """ from __future__ import annotations from pathlib import Path import nibabel as nib import numpy as np import torch import torch.utils.data class SynthSegGenerator(torch.utils.data.Dataset): """SynthSeg-style synthetic brain data generator. Generates synthetic brain images from label maps with domain randomization for contrast-agnostic training. Parameters ---------- label_maps : list of str or Path Paths to NIfTI label-map files (e.g., FreeSurfer aparc+aseg). n_samples_per_map : int Number of synthetic samples per label map. generation_classes : dict or None Tissue class grouping: ``{"WM": [2, 41], ...}``. Labels in the same class share one intensity distribution. None = use default FreeSurfer tissue classes. intensity_prior : tuple of float ``(min, max)`` bounds for sampling per-class mean intensities. std_prior : tuple of float ``(min, max)`` bounds for sampling per-class std. noise_std : float Additive Gaussian noise std. bias_field_std : float Bias field magnitude (std of polynomial coefficients). elastic_std : float Elastic deformation magnitude (0 = disabled). rotation_range : float Max rotation in degrees per axis (0 = disabled). scaling_bounds : float Max scaling fraction (e.g., 0.2 = ±20%). flipping : bool Enable random left-right flipping with label remapping. randomize_resolution : bool Simulate variable acquisition resolution. resolution_range : tuple of float ``(min_mm, max_mm)`` per-axis resolution range. 
""" def __init__( self, label_maps: list[str | Path], n_samples_per_map: int = 10, generation_classes: dict[str, list[int]] | None = None, intensity_prior: tuple[float, float] = (0.0, 250.0), std_prior: tuple[float, float] = (0.0, 35.0), noise_std: float = 0.1, bias_field_std: float = 0.7, elastic_std: float = 4.0, rotation_range: float = 15.0, scaling_bounds: float = 0.2, flipping: bool = True, randomize_resolution: bool = True, resolution_range: tuple[float, float] = (1.0, 3.0), seed: int | None = None, ) -> None: self.label_maps = [Path(p) for p in label_maps] self._seed = seed self.n_samples_per_map = n_samples_per_map self.intensity_prior = intensity_prior self.std_prior = std_prior self.noise_std = noise_std self.bias_field_std = bias_field_std self.elastic_std = elastic_std self.rotation_range = rotation_range self.scaling_bounds = scaling_bounds self.flipping = flipping self.randomize_resolution = randomize_resolution self.resolution_range = resolution_range # Load tissue class mapping if generation_classes is None: from nobrainer.data.tissue_classes import FREESURFER_TISSUE_CLASSES self.generation_classes = FREESURFER_TISSUE_CLASSES else: self.generation_classes = generation_classes # Build reverse lookup: label_id → class_name self._label_to_class: dict[int, str] = {} for cls_name, label_ids in self.generation_classes.items(): for lid in label_ids: self._label_to_class[lid] = cls_name def __len__(self) -> int: return len(self.label_maps) * self.n_samples_per_map def _get_rng(self, idx: int) -> np.random.Generator: """Get a seeded RNG for reproducibility, or unseeded if no seed.""" if self._seed is not None: return np.random.default_rng(self._seed + idx) return np.random.default_rng() def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: map_idx = idx // self.n_samples_per_map label_path = self.label_maps[map_idx] # Load label map label_data = np.asarray(nib.load(label_path).dataobj, dtype=np.int32) # 1. 
GMM intensity generation (per tissue class) image = self._generate_intensities(label_data) # 2. Spatial augmentation (elastic + affine + flip) if self.elastic_std > 0 or self.rotation_range > 0 or self.flipping: image, label_data = self._spatial_augmentation(image, label_data) # 3. Resolution randomization if self.randomize_resolution: image = self._randomize_resolution(image) # 4. Bias field if self.bias_field_std > 0: image = self._add_bias_field(image) # 5. Gaussian noise if self.noise_std > 0: image = image + np.random.normal(0, self.noise_std, image.shape).astype( np.float32 ) # Convert to tensors with channel dim [1, D, H, W] image_t = torch.from_numpy(image).float().unsqueeze(0) label_t = torch.from_numpy(label_data).long().unsqueeze(0) return {"image": image_t, "label": label_t} # ------------------------------------------------------------------ # GMM intensity generation # ------------------------------------------------------------------ def _generate_intensities(self, label_data: np.ndarray) -> np.ndarray: """Generate image by sampling GMM intensities per tissue class.""" rng = np.random.default_rng() unique_labels = np.unique(label_data) # Sample one (mean, std) per tissue class class_params: dict[str, tuple[float, float]] = {} for cls_name in self.generation_classes: mean = rng.uniform(*self.intensity_prior) std = rng.uniform(*self.std_prior) class_params[cls_name] = (mean, std) # Fill each label region from its class distribution image = np.zeros_like(label_data, dtype=np.float32) for lab in unique_labels: mask = label_data == lab n_vox = int(mask.sum()) if n_vox == 0: continue cls_name = self._label_to_class.get(lab) if cls_name is not None and cls_name in class_params: mean, std = class_params[cls_name] else: # Unknown label: sample fresh random params mean = rng.uniform(*self.intensity_prior) std = rng.uniform(*self.std_prior) image[mask] = rng.normal(mean, max(std, 1e-6), size=n_vox).astype( np.float32 ) return image # 
------------------------------------------------------------------ # Spatial augmentation # ------------------------------------------------------------------ def _spatial_augmentation( self, image: np.ndarray, label: np.ndarray ) -> tuple[np.ndarray, np.ndarray]: """Apply elastic deformation, affine transform, and flipping.""" from scipy.ndimage import map_coordinates D, H, W = image.shape # Build coordinate grid coords = np.mgrid[:D, :H, :W].astype(np.float32) # (3, D, H, W) # Elastic deformation: smooth random displacement field if self.elastic_std > 0: # Sample on coarse grid, smooth, then resize coarse_shape = (max(4, D // 8), max(4, H // 8), max(4, W // 8)) rng = np.random.default_rng() for axis in range(3): displacement = rng.normal(0, self.elastic_std, coarse_shape).astype( np.float32 ) # Smooth from scipy.ndimage import gaussian_filter, zoom displacement = gaussian_filter(displacement, sigma=2.0) # Resize to full volume zoom_factors = ( D / coarse_shape[0], H / coarse_shape[1], W / coarse_shape[2], ) displacement = zoom(displacement, zoom_factors, order=1) # Crop/pad to exact shape if needed displacement = displacement[:D, :H, :W] coords[axis] += displacement # Affine: rotation + scaling if self.rotation_range > 0 or self.scaling_bounds > 0: center = np.array([D / 2, H / 2, W / 2]) coords_centered = coords.reshape(3, -1) - center[:, None] # Build rotation matrix (Euler angles) rng = np.random.default_rng() angles = rng.uniform(-self.rotation_range, self.rotation_range, size=3) angles_rad = np.deg2rad(angles) Rx = _rot_x(angles_rad[0]) Ry = _rot_y(angles_rad[1]) Rz = _rot_z(angles_rad[2]) R = Rz @ Ry @ Rx # Scaling if self.scaling_bounds > 0: scale = rng.uniform( 1 - self.scaling_bounds, 1 + self.scaling_bounds, size=3 ) S = np.diag(scale) R = R @ S coords_centered = R @ coords_centered coords = (coords_centered + center[:, None]).reshape(3, D, H, W) # Apply spatial transform image_out = map_coordinates(image, coords, order=3, mode="nearest") label_out = 
map_coordinates( label.astype(np.float32), coords, order=0, mode="nearest" ).astype(np.int32) # Flipping if self.flipping and np.random.random() > 0.5: image_out = np.flip(image_out, axis=2).copy() # flip W axis (L/R) label_out = np.flip(label_out, axis=2).copy() label_out = self._remap_lr_labels(label_out) return image_out.astype(np.float32), label_out @staticmethod def _remap_lr_labels(label: np.ndarray) -> np.ndarray: """Swap left/right FreeSurfer labels after L/R flip.""" from nobrainer.data.tissue_classes import FREESURFER_LR_PAIRS result = label.copy() for left, right in FREESURFER_LR_PAIRS: left_mask = label == left right_mask = label == right result[left_mask] = right result[right_mask] = left return result # ------------------------------------------------------------------ # Resolution randomization # ------------------------------------------------------------------ def _randomize_resolution(self, image: np.ndarray) -> np.ndarray: """Simulate variable MRI acquisition resolution.""" from scipy.ndimage import gaussian_filter, zoom rng = np.random.default_rng() target_res = rng.uniform(*self.resolution_range, size=3) # Downsample with anti-aliasing sigmas = [max(0, (r - 1) / 2) for r in target_res] blurred = gaussian_filter(image, sigma=sigmas) # Downsample then upsample down_factors = [1.0 / r for r in target_res] downsampled = zoom(blurred, down_factors, order=1) up_factors = [image.shape[i] / downsampled.shape[i] for i in range(3)] upsampled = zoom(downsampled, up_factors, order=1) # Ensure exact shape match D, H, W = image.shape return upsampled[:D, :H, :W].astype(np.float32) # ------------------------------------------------------------------ # Bias field # ------------------------------------------------------------------ def _add_bias_field(self, image: np.ndarray) -> np.ndarray: """Apply smooth multiplicative bias field.""" D, H, W = image.shape order = 3 coords_d = np.linspace(-1, 1, D) coords_h = np.linspace(-1, 1, H) coords_w = np.linspace(-1, 1, 
W) rng = np.random.default_rng() coeffs = rng.normal(0, self.bias_field_std, (order + 1, order + 1, order + 1)) bias = np.zeros_like(image) for i in range(order + 1): for j in range(order + 1): for k in range(order + 1): term = coeffs[i, j, k] term = term * np.power(coords_d, i)[:, None, None] term = term * np.power(coords_h, j)[None, :, None] term = term * np.power(coords_w, k)[None, None, :] bias += term bias = np.exp(bias) return (image * bias).astype(np.float32) # ------------------------------------------------------------------ # Rotation matrix helpers # ------------------------------------------------------------------ def _rot_x(angle: float) -> np.ndarray: c, s = np.cos(angle), np.sin(angle) return np.array([[1, 0, 0], [0, c, -s], [0, s, c]]) def _rot_y(angle: float) -> np.ndarray: c, s = np.cos(angle), np.sin(angle) return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]]) def _rot_z(angle: float) -> np.ndarray: c, s = np.cos(angle), np.sin(angle) return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]) ================================================ FILE: nobrainer/augmentation/transforms.py ================================================ """Augmentation tagging for MONAI transform pipelines. Extends MONAI's ``Compose`` so individual transforms can be tagged as **augmentation** (train-only) or **preprocessing** (always runs). During inference/prediction, augmentation-tagged transforms are automatically skipped. 
Usage:: from nobrainer.augmentation.transforms import Augmentation, TrainableCompose from monai.transforms import RandAffined, RandGaussianNoised, LoadImaged pipeline = TrainableCompose([ LoadImaged(keys=["image", "label"]), # preprocessing Augmentation(RandAffined(keys=["image", "label"], ...)), # train-only Augmentation(RandGaussianNoised(keys=["image"], ...)), # train-only ]) # Training: all transforms run result = pipeline(data, mode="train") # Predict: augmentation transforms are skipped result = pipeline(data, mode="predict") """ from __future__ import annotations from typing import Any from monai.transforms import Compose class Augmentation: """Wrapper that tags a MONAI transform as train-only (augmentation). When used inside a :class:`TrainableCompose`, this transform is automatically skipped when ``mode="predict"``. Parameters ---------- transform : callable Any MONAI dictionary transform. """ is_augmentation = True def __init__(self, transform: Any) -> None: self.transform = transform def __call__(self, data: Any) -> Any: return self.transform(data) def __repr__(self) -> str: return f"Augmentation({self.transform!r})" class TrainableCompose(Compose): """MONAI Compose that skips augmentation-tagged transforms in predict mode. Behaves identically to ``monai.transforms.Compose`` in train mode. In predict mode, any transform wrapped with :class:`Augmentation` (or having ``is_augmentation = True``) is skipped. Parameters ---------- transforms : list List of MONAI transforms, optionally wrapped with :class:`Augmentation`. mode : str Default mode: ``"train"`` or ``"predict"``. Can be overridden per-call via ``__call__(data, mode=...)``. 
""" def __init__(self, transforms: list, mode: str = "train") -> None: super().__init__(transforms) self._mode = mode @property def mode(self) -> str: return self._mode @mode.setter def mode(self, value: str) -> None: if value not in ("train", "predict"): raise ValueError(f"mode must be 'train' or 'predict', got '{value}'") self._mode = value def __call__(self, data: Any, mode: str | None = None, **kwargs) -> Any: """Apply transforms, skipping augmentation in predict mode. Extra keyword arguments (e.g., ``end``, ``threading``) are passed through to MONAI's ``Compose.__call__`` for CacheDataset compat. """ active_mode = mode or self._mode if active_mode == "train": # All transforms run — pass through MONAI kwargs return super().__call__(data, **kwargs) # Predict mode: skip augmentation transforms result = data for t in self.transforms: if getattr(t, "is_augmentation", False): continue result = t(result) return result ================================================ FILE: nobrainer/cli/__init__.py ================================================ ================================================ FILE: nobrainer/cli/main.py ================================================ """Main command-line interface for nobrainer.""" from __future__ import annotations import datetime import os import platform import sys import click import nibabel as nib import numpy as np import torch from .. 
import __version__
from ..prediction import predict as _predict
from ..training import get_device

_option_kwds = {"show_default": True}


def _split_nifti_ext(path: str) -> tuple[str, str]:
    """Split ``path`` into ``(stem, ext)``, treating ``.nii.gz`` as one extension.

    Only the final path component is inspected, so directory names that
    contain ``.nii`` are left untouched.
    """
    stem, ext = os.path.splitext(path)
    if ext == ".gz":
        stem, inner = os.path.splitext(stem)
        ext = inner + ext
    return stem, ext


class JSONParamType(click.ParamType):
    name = "json"

    def convert(self, value, param, ctx):
        try:
            import json

            return json.loads(value)
        except Exception:
            self.fail(f"{value} is not valid JSON", param, ctx)


@click.group()
@click.version_option(__version__, message="%(prog)s version %(version)s")
def cli():
    """A framework for developing neural network models for 3D image processing."""
    return


@cli.command()
@click.argument("infile")
@click.argument("outfile")
@click.option(
    "-m",
    "--model",
    type=click.Path(exists=True),
    required=True,
    help="Path to PyTorch model file (.pth) or model name.",
    **_option_kwds,
)
@click.option(
    "--model-type",
    default="unet",
    help=(
        "Model architecture: unet, vnet, attention_unet, unetr, meshnet, "
        "highresnet, bayesian_vnet, bayesian_meshnet."
    ),
    **_option_kwds,
)
@click.option(
    "--n-classes",
    type=int,
    default=1,
    help="Number of output classes.",
    **_option_kwds,
)
@click.option(
    "--in-channels",
    type=int,
    default=1,
    help="Number of input channels.",
    **_option_kwds,
)
@click.option(
    "-b",
    "--block-shape",
    default=(128, 128, 128),
    type=int,
    nargs=3,
    help="Shape of sub-volumes on which to predict.",
    **_option_kwds,
)
@click.option(
    "--batch-size",
    type=int,
    default=4,
    help="Number of blocks to process per forward pass.",
    **_option_kwds,
)
@click.option(
    "--n-samples",
    type=int,
    default=1,
    help="Monte-Carlo samples for Bayesian uncertainty estimation (>1 enables MC-Dropout).",
    **_option_kwds,
)
@click.option(
    "--device",
    default="auto",
    help='Compute device: "auto", "cpu", "cuda", "cuda:0", …',
    **_option_kwds,
)
@click.option(
    "-v", "--verbose", is_flag=True, help="Print progress messages.", **_option_kwds
)
def predict(
    *,
    infile,
    outfile,
    model,
    model_type,
    n_classes,
    in_channels,
    block_shape,
    batch_size,
    n_samples,
    device,
    verbose,
):
    """Predict labels from a NIfTI volume using a trained PyTorch model.

    The predictions are saved to OUTFILE. With ``--n-samples > 1`` the
    variance and entropy maps are written next to it as
    ``<stem>_var<ext>`` and ``<stem>_entropy<ext>``.
    """
    if os.path.exists(outfile):
        raise FileExistsError(f"Output file already exists: {outfile}")

    # Resolve device
    if device == "auto":
        _device = get_device()
    else:
        _device = torch.device(device)

    if verbose:
        click.echo(f"Using device: {_device}")

    # Load model architecture + weights
    from ..models import get as _get_model

    try:
        factory = _get_model(model_type)
        pt_model = factory(n_classes=n_classes, in_channels=in_channels)
        state = torch.load(model, map_location=_device, weights_only=True)
        pt_model.load_state_dict(state, strict=False)
    except Exception as exc:
        click.echo(click.style(f"ERROR: could not load model: {exc}", fg="red"))
        raise SystemExit(1) from exc

    if verbose:
        click.echo("Running prediction ...")

    if n_samples > 1:
        from ..prediction import predict_with_uncertainty

        try:
            label_img, var_img, entropy_img = predict_with_uncertainty(
                infile,
                pt_model,
                n_samples=n_samples,
                block_shape=block_shape,
                batch_size=batch_size,
                device=_device,
            )
            nib.save(label_img, outfile)
            # Derive sibling paths from the filename's own extension. A plain
            # str.replace(".nii", ...) leaves non-.nii names unchanged, which
            # would overwrite OUTFILE, and it also rewrites ".nii" in dirs.
            stem, ext = _split_nifti_ext(outfile)
            nib.save(var_img, f"{stem}_var{ext}")
            nib.save(entropy_img, f"{stem}_entropy{ext}")
        except NotImplementedError:
            click.echo(
                click.style(
                    "predict_with_uncertainty not yet implemented; "
                    "falling back to deterministic predict()",
                    fg="yellow",
                )
            )
            out_img = _predict(
                infile,
                pt_model,
                block_shape=block_shape,
                batch_size=batch_size,
                device=_device,
            )
            nib.save(out_img, outfile)
    else:
        out_img = _predict(
            infile,
            pt_model,
            block_shape=block_shape,
            batch_size=batch_size,
            device=_device,
        )
        nib.save(out_img, outfile)

    if verbose:
        click.echo(click.style(f"Output saved to {outfile}", fg="green"))


@cli.command()
@click.option(
    "-i",
    "--input",
    "input_paths",
    multiple=True,
    type=click.Path(exists=True),
    required=True,
    help="TFRecord file(s) to convert.",
    **_option_kwds,
)
@click.option(
    "-o",
    "--output-dir",
    required=True,
    type=click.Path(),
    help="Output directory for NIfTI or HDF5 files.",
    **_option_kwds,
)
@click.option(
    "--format",
    "output_format",
    default="nifti",
    type=click.Choice(["nifti", "hdf5"]),
    help="Output format.",
    **_option_kwds,
)
@click.option(
    "-v", "--verbose", is_flag=True, help="Print progress messages.", **_option_kwds
)
def convert_tfrecords(*, input_paths, output_dir, output_format, verbose):
    """Convert TFRecord files to NIfTI or HDF5 (no TensorFlow required)."""
    from ..io import convert_tfrecords as _convert

    if verbose:
        click.echo(f"Converting {len(input_paths)} TFRecord file(s) …")

    out_paths = _convert(
        tfrecord_paths=list(input_paths),
        output_dir=output_dir,
        output_format=output_format,
    )

    if verbose:
        for p in out_paths:
            click.echo(f" → {p}")
    click.echo(click.style(f"Done. {len(out_paths)} files written.", fg="green"))


@cli.command()
@click.argument("output", type=click.Path())
@click.option(
    "-i",
    "--images",
    multiple=True,
    type=click.Path(exists=True),
    required=True,
    help="Image NIfTI files.",
    **_option_kwds,
)
@click.option(
    "-l",
    "--labels",
    multiple=True,
    type=click.Path(exists=True),
    required=True,
    help="Label NIfTI files (same order as --images).",
    **_option_kwds,
)
@click.option(
    "--chunk-shape",
    default="32,32,32",
    help="Chunk shape (comma-separated).",
    **_option_kwds,
)
@click.option("--no-conform", is_flag=True, help="Disable auto-conforming.")
@click.option("-v", "--verbose", is_flag=True, help="Print progress.")
def convert_to_zarr(*, output, images, labels, chunk_shape, no_conform, verbose):
    """Convert NIfTI image+label pairs to a sharded Zarr3 store."""
    from ..datasets.zarr_store import create_zarr_store

    if len(images) != len(labels):
        click.echo(
            click.style(
                f"Error: {len(images)} images but {len(labels)} labels.", fg="red"
            )
        )
        sys.exit(1)

    pairs = list(zip(images, labels))
    # Report a malformed --chunk-shape as a usage error, not a traceback.
    try:
        chunks = tuple(int(x) for x in chunk_shape.split(","))
    except ValueError as exc:
        raise click.BadParameter(
            f"expected comma-separated integers, got {chunk_shape!r}",
            param_hint="--chunk-shape",
        ) from exc
    if any(c <= 0 for c in chunks):
        raise click.BadParameter(
            f"chunk sizes must be positive, got {chunk_shape!r}",
            param_hint="--chunk-shape",
        )

    if verbose:
        click.echo(f"Converting {len(pairs)} pairs → {output}")

    store_path = create_zarr_store(
        pairs,
        output,
        chunk_shape=chunks,
        conform=not no_conform,
    )
    click.echo(click.style(f"Zarr store created: {store_path}", fg="green"))


@cli.command()
def merge():
    """Merge multiple models trained with variational weights."""
    click.echo("Not implemented yet.")
    sys.exit(-2)


@cli.command()
@click.argument("outfile")
@click.option(
    "-m",
    "--model",
    type=click.Path(exists=True),
    required=True,
    help="Path to model checkpoint (.ckpt) or weights (.pth).",
    **_option_kwds,
)
@click.option(
    "--model-type",
    default="progressivegan",
    type=click.Choice(["progressivegan", "dcgan"]),
    help="Generative model architecture.",
    **_option_kwds,
)
@click.option(
    "--latent-size",
    type=int,
    default=512,
    help="Latent vector dimension.",
    **_option_kwds,
)
@click.option(
    "--n-samples",
    type=int,
    default=1,
    help="Number of images to generate.",
    **_option_kwds,
)
@click.option(
    "--device",
    default="auto",
    help='Compute device: "auto", "cpu", "cuda", …',
    **_option_kwds,
)
@click.option(
    "-v", "--verbose", is_flag=True, help="Print progress messages.", **_option_kwds
)
def generate(
    *,
    outfile,
    model,
    model_type,
    latent_size,
    n_samples,
    device,
    verbose,
):
    """Generate brain volumes from a trained GAN model.

    Saves OUTFILE (NIfTI) for each generated sample. When ``--n-samples > 1``
    the file stem is suffixed with ``_0``, ``_1``, … before the extension.
    """
    if device == "auto":
        _device = get_device()
    else:
        _device = torch.device(device)

    if verbose:
        click.echo(f"Using device: {_device}")

    from ..models import get as _get_model

    try:
        factory = _get_model(model_type)
        pt_model = factory(latent_size=latent_size)
        # Support both .ckpt (Lightning) and .pth (state dict)
        if model.endswith(".ckpt"):
            model_cls = type(pt_model)
            pt_model = model_cls.load_from_checkpoint(model, map_location=_device)
        else:
            state = torch.load(model, map_location=_device, weights_only=True)
            pt_model.load_state_dict(state, strict=False)
    except Exception as exc:
        click.echo(click.style(f"ERROR: could not load model: {exc}", fg="red"))
        raise SystemExit(1) from exc

    pt_model = pt_model.to(_device)
    pt_model.eval()

    if verbose:
        click.echo(f"Generating {n_samples} sample(s) …")

    stem, ext = _split_nifti_ext(outfile)

    with torch.no_grad():
        for i in range(n_samples):
            z = torch.randn(1, latent_size, device=_device)
            out = pt_model.generator(z)  # (1, 1, D, H, W)
            arr = out.squeeze().cpu().numpy()
            img = nib.Nifti1Image(arr.astype(np.float32), np.eye(4))
            path = f"{stem}_{i}{ext}" if n_samples > 1 else outfile
            nib.save(img, path)
            if verbose:
                click.echo(f" Saved {path}")

    if verbose:
        click.echo(click.style("Done.", fg="green"))


@cli.command()
@click.option(
    "--working-dir",
    required=True,
    type=click.Path(),
    help="Directory with train script and data_manifest.json.",
    **_option_kwds,
)
@click.option(
    "--model-family",
    default="bayesian_vnet",
    help="Model family to use for training.",
    **_option_kwds,
)
@click.option(
    "--max-experiments",
    type=int,
    default=10,
    help="Maximum number of experiments.",
    **_option_kwds,
)
@click.option(
    "--budget-hours",
    type=float,
    default=8.0,
    help="Wall-clock budget in hours.",
    **_option_kwds,
)
@click.option(
    "--budget-minutes",
    type=float,
    default=None,
    help="Wall-clock budget in minutes (overrides --budget-hours).",
    **_option_kwds,
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help="Print per-experiment progress.",
    **_option_kwds,
)
def research(
    *,
    working_dir,
    model_family,
    max_experiments,
    budget_hours,
    budget_minutes,
    verbose,
):
    """Run the autoresearch experiment loop.

    Proposes hyperparameter configs (via Anthropic API or random grid),
    runs training experiments, and keeps improvements. Writes
    ``run_summary.md`` in WORKING_DIR on completion.
    """
    from ..research.loop import run_loop

    if verbose:
        import logging

        logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")

    budget_seconds = None
    if budget_minutes is not None:
        budget_seconds = budget_minutes * 60

    results = run_loop(
        working_dir=working_dir,
        model_family=model_family,
        max_experiments=max_experiments,
        budget_hours=budget_hours,
        budget_seconds=budget_seconds,
    )

    # Progress table
    click.echo(
        f"\n{'run_id':>6} {'val_dice':>10} {'outcome':<12} {'failure_reason'}"
    )
    click.echo("-" * 55)
    for r in results:
        dice_str = f"{r.val_dice:.4f}" if r.val_dice is not None else "—"
        click.echo(
            f"{r.run_id:>6} {dice_str:>10} {r.outcome:<12} {r.failure_reason or '—'}"
        )

    summary_path = click.format_filename(f"{working_dir}/run_summary.md")
    click.echo(click.style(f"\nSummary written to {summary_path}", fg="green"))


@cli.command()
@click.option(
    "--model-path",
    required=True,
    type=click.Path(exists=True),
    help="Path to best_model.pth file.",
    **_option_kwds,
)
@click.option(
    "--config-path",
    required=True,
    type=click.Path(exists=True),
    help="Path to best_config.json file.",
    **_option_kwds,
)
@click.option(
    "--trained-models-path",
    required=True,
    type=click.Path(),
    help="Root of the DataLad-managed trained_models dataset.",
    **_option_kwds,
)
@click.option(
    "--model-family",
    default="bayesian_vnet",
    help="Model family name (used as subdirectory).",
    **_option_kwds,
)
@click.option(
    "--val-dice",
    type=float,
    required=True,
    help="Validation Dice score of the best model.",
    **_option_kwds,
)
@click.option(
    "--source-run-id",
    default="",
    help="Run ID string for traceability.",
    **_option_kwds,
)
def commit(
    *,
    model_path,
    config_path,
    trained_models_path,
    model_family,
    val_dice,
    source_run_id,
):
    """Version the best model with DataLad and push to OSF.

    Copies model weights and config into the trained_models DataLad dataset,
    generates a model card, saves with DataLad, and pushes to OSF.
    """
    from ..research.loop import commit_best_model

    try:
        result = commit_best_model(
            best_model_path=model_path,
            best_config_path=config_path,
            trained_models_path=trained_models_path,
            model_family=model_family,
            val_dice=val_dice,
            source_run_id=source_run_id,
        )
    except ImportError as exc:
        click.echo(click.style(f"ERROR: {exc}", fg="red"))
        raise SystemExit(1) from exc

    click.echo(f"Model versioned at: {result['path']}")
    click.echo(f"DataLad commit: {result['datalad_commit']}")
    if result.get("osf_url"):
        click.echo(click.style(f"OSF URL: {result['osf_url']}", fg="green"))
    else:
        click.echo(click.style("OSF push skipped (no remote configured)", fg="yellow"))


@cli.command()
def save():
    """Save a model to PyTorch format."""
    click.echo("Not implemented yet.")
    sys.exit(-2)


@cli.command()
def evaluate():
    """Evaluate a model's predictions against known labels."""
    click.echo("Not implemented yet.")
    sys.exit(-2)


@cli.command()
def info():
    """Return information about this system."""
    uname = platform.uname()
    cuda_available = torch.cuda.is_available()
    cuda_devices = torch.cuda.device_count() if cuda_available else 0
    # Timezone-aware "now" replaces the deprecated datetime.utcnow(); the
    # printed UTC timestamp is unchanged.
    timestamp = datetime.datetime.now(datetime.timezone.utc)
    s = f"""\
Python:
  Version: {platform.python_version()}
  Implementation: {platform.python_implementation()}
  64-bit: {sys.maxsize > 2**32}
  Packages:
    Nobrainer: {__version__}
    Nibabel: {nib.__version__}
    Numpy: {np.__version__}
    PyTorch: {torch.__version__}
    CUDA available: {cuda_available}
    CUDA devices: {cuda_devices}
System:
  OSType: {uname.system}
  Release: {uname.release}
  Version: {uname.version}
  Architecture: {uname.machine}
Timestamp: {timestamp.strftime('%Y/%m/%d %T')}"""
    click.echo(s)


# For debugging only.
if __name__ == "__main__": cli() ================================================ FILE: nobrainer/cli/tests/__init__.py ================================================ ================================================ FILE: nobrainer/cli/tests/main_test.py ================================================ """Tests for `nobrainer.cli.main`.""" import csv from pathlib import Path from click.testing import CliRunner import nibabel as nib import numpy as np import pytest from .. import main as climain from ...io import read_csv from ...models.meshnet import meshnet from ...models.progressivegan import progressivegan from ...utils import get_data def test_convert_nonscalar_labels(tmp_path): runner = CliRunner() with runner.isolated_filesystem(): csvpath = get_data(tmp_path) tfrecords_template = Path("data/shard-{shard:03d}.tfrecords") tfrecords_template.parent.mkdir(exist_ok=True) args = """\ convert --csv={} --tfrecords-template={} --volume-shape 256 256 256 --examples-per-shard=2 --to-ras --no-verify-volumes """.format( csvpath, tfrecords_template ) result = runner.invoke(climain.cli, args.split()) assert result.exit_code == 0 assert Path("data/shard-000.tfrecords").is_file() assert Path("data/shard-001.tfrecords").is_file() assert Path("data/shard-002.tfrecords").is_file() assert Path("data/shard-003.tfrecords").is_file() assert Path("data/shard-004.tfrecords").is_file() assert not Path("data/shard-005.tfrecords").is_file() def test_convert_scalar_int_labels(tmp_path): runner = CliRunner() with runner.isolated_filesystem(): csvpath = get_data(str(tmp_path)) # Make labels scalars. 
data = [(x, 0) for (x, _) in read_csv(csvpath)] csvpath = tmp_path.with_suffix(".new.csv") with open(csvpath, "w", newline="") as myfile: wr = csv.writer(myfile, quoting=csv.QUOTE_ALL) wr.writerows(data) tfrecords_template = Path("data/shard-{shard:03d}.tfrecords") tfrecords_template.parent.mkdir(exist_ok=True) args = """\ convert --csv={} --tfrecords-template={} --volume-shape 256 256 256 --examples-per-shard=2 --to-ras --no-verify-volumes """.format( csvpath, tfrecords_template ) result = runner.invoke(climain.cli, args.split()) assert result.exit_code == 0 assert Path("data/shard-000.tfrecords").is_file() assert Path("data/shard-001.tfrecords").is_file() assert Path("data/shard-002.tfrecords").is_file() assert Path("data/shard-003.tfrecords").is_file() assert Path("data/shard-004.tfrecords").is_file() assert not Path("data/shard-005.tfrecords").is_file() def test_convert_scalar_float_labels(tmp_path): runner = CliRunner() with runner.isolated_filesystem(): csvpath = get_data(str(tmp_path)) # Make labels scalars. 
data = [(x, 1.0) for (x, _) in read_csv(csvpath)] csvpath = tmp_path.with_suffix(".new.csv") with open(csvpath, "w", newline="") as myfile: wr = csv.writer(myfile, quoting=csv.QUOTE_ALL) wr.writerows(data) tfrecords_template = Path("data/shard-{shard:03d}.tfrecords") tfrecords_template.parent.mkdir(exist_ok=True) args = """\ convert --csv={} --tfrecords-template={} --volume-shape 256 256 256 --examples-per-shard=2 --to-ras --no-verify-volumes """.format( csvpath, tfrecords_template ) result = runner.invoke(climain.cli, args.split()) assert result.exit_code == 0 assert Path("data/shard-000.tfrecords").is_file() assert Path("data/shard-001.tfrecords").is_file() assert Path("data/shard-002.tfrecords").is_file() assert Path("data/shard-003.tfrecords").is_file() assert Path("data/shard-004.tfrecords").is_file() assert not Path("data/shard-005.tfrecords").is_file() def test_convert_multi_resolution(tmp_path): runner = CliRunner() with runner.isolated_filesystem(): csvpath = get_data(str(tmp_path)) # Make labels scalars. 
data = [(x, 1.0) for (x, _) in read_csv(csvpath)] csvpath = tmp_path.with_suffix(".new.csv") with open(csvpath, "w", newline="") as myfile: wr = csv.writer(myfile, quoting=csv.QUOTE_ALL) wr.writerows(data) tfrecords_template = Path("data/shard-{shard:03d}.tfrecords") tfrecords_template.parent.mkdir(exist_ok=True) args = """\ convert --csv={} --tfrecords-template={} --volume-shape 256 256 256 --start-resolution 64 --examples-per-shard=2 --no-verify-volumes --multi-resolution """.format( csvpath, tfrecords_template ) result = runner.invoke(climain.cli, args.split()) assert result.exit_code == 0 resolutions = [64, 128, 256] for res in resolutions: assert Path("data/shard-000-res-{:03d}.tfrecords".format(res)).is_file() assert Path("data/shard-001-res-{:03d}.tfrecords".format(res)).is_file() assert Path("data/shard-002-res-{:03d}.tfrecords".format(res)).is_file() assert Path("data/shard-003-res-{:03d}.tfrecords".format(res)).is_file() assert Path("data/shard-004-res-{:03d}.tfrecords".format(res)).is_file() assert not Path("data/shard-005-res-{:03d}.tfrecords".format(res)).is_file() @pytest.mark.xfail def test_merge(): assert False def test_predict(): runner = CliRunner() with runner.isolated_filesystem(): model = meshnet(1, (10, 10, 10, 1)) model_path = "model.h5" model.save(model_path) img_path = "features.nii.gz" nib.Nifti1Image(np.random.randn(20, 20, 20), np.eye(4)).to_filename(img_path) out_path = "predictions.nii.gz" args = """\ predict --model={} --block-shape 10 10 10 --resize-features-to 20 20 20 --largest-label --rotate-and-predict {} {} """.format( model_path, img_path, out_path ) result = runner.invoke(climain.cli, args.split()) assert result.exit_code == 0 assert Path("predictions.nii.gz").is_file() assert nib.load(out_path).shape == (20, 20, 20) def test_generate(): runner = CliRunner() with runner.isolated_filesystem(): generator, _ = progressivegan( latent_size=256, g_fmap_base=1024, d_fmap_base=1024 ) resolutions = [8, 16] 
Path("models").mkdir(exist_ok=True) for res in resolutions: generator.add_resolution() generator([np.random.random((1, 256)), 1.0]) # to build the model by a call model_path = "models/generator_res_{}".format(res) generator.save(model_path) assert Path(model_path).is_dir() out_path = "generated.nii.gz" args = """\ generate --model {} --multi-resolution --latent-size 256 {} """.format( "models", out_path ) result = runner.invoke(climain.cli, args.split()) assert result.exit_code == 0 for res in resolutions: assert Path("generated_res_{}.nii.gz".format(res)).is_file() assert nib.load("generated_res_{}.nii.gz".format(res)).shape == ( res, res, res, ) @pytest.mark.xfail def test_save(): assert False @pytest.mark.xfail def test_evaluate(): assert False def test_info(): runner = CliRunner() result = runner.invoke(climain.cli, ["info"]) assert result.exit_code == 0 assert "Python" in result.output assert "System" in result.output assert "Timestamp" in result.output ================================================ FILE: nobrainer/dataset.py ================================================ """PyTorch dataset utilities backed by MONAI.""" from __future__ import annotations from pathlib import Path from typing import Any, Callable from monai.data import CacheDataset, DataLoader from monai.transforms import ( EnsureChannelFirstd, LoadImaged, NormalizeIntensityd, Orientationd, Spacingd, ) import numpy as np import torch def get_dataset( image_paths: list[str | Path], label_paths: list[str | Path] | None = None, block_shape: tuple[int, int, int] | None = None, batch_size: int = 1, num_workers: int = 0, augment: bool = False, binarize_labels: bool | set | Callable = False, target_spacing: tuple[float, float, float] = (1.0, 1.0, 1.0), cache_rate: float = 1.0, **kwargs: Any, ) -> DataLoader: """Build a MONAI-backed :class:`torch.utils.data.DataLoader`. 
Applies the following transform chain: ``LoadImaged → EnsureChannelFirstd → Orientationd("RAS") → Spacingd(*target_spacing) → NormalizeIntensityd`` → (if augment) ``RandAffined, RandFlipd, RandGaussianNoised`` Parameters ---------- image_paths : list Paths to input NIfTI volumes. label_paths : list or None Paths to corresponding label NIfTI volumes. ``None`` for inference-only datasets. block_shape : tuple or None If provided, spatial patch size ``(D, H, W)`` extracted by MONAI's ``RandSpatialCropd``. ``None`` loads full volumes. batch_size : int Number of samples per mini-batch. num_workers : int Number of DataLoader worker processes. augment : bool Whether to apply random spatial and intensity augmentations. target_spacing : tuple of float Voxel spacing (mm) to resample volumes to. cache_rate : float Fraction of dataset to cache in memory (1.0 = all). **kwargs Additional keyword arguments forwarded to :class:`DataLoader`. Returns ------- DataLoader PyTorch DataLoader that yields batches of ``{"image": tensor}`` (or ``{"image": tensor, "label": tensor}`` when labels are given). 
""" if label_paths is not None and len(image_paths) != len(label_paths): raise ValueError( f"len(image_paths)={len(image_paths)} != len(label_paths)={len(label_paths)}" ) has_labels = label_paths is not None # Build data dicts if has_labels: data = [ {"image": str(img), "label": str(lbl)} for img, lbl in zip(image_paths, label_paths) ] keys = ["image", "label"] else: data = [{"image": str(img)} for img in image_paths] keys = ["image"] # Core transforms — use NibabelReader to support .mgz and other formats transforms: list[Any] = [ LoadImaged(keys=keys, image_only=False, reader="NibabelReader"), EnsureChannelFirstd(keys=keys), Orientationd(keys=keys, axcodes="RAS"), Spacingd( keys=keys, pixdim=target_spacing, mode=["bilinear", "nearest"] if has_labels else ["bilinear"], ), NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True), ] # Optional label binarization (e.g., FreeSurfer parcellation → brain mask) if binarize_labels and has_labels: from monai.transforms import Lambdad if callable(binarize_labels) and binarize_labels is not True: transforms.append(Lambdad(keys=["label"], func=binarize_labels)) elif isinstance(binarize_labels, set): label_set = binarize_labels def _remap(x): import torch mask = torch.zeros_like(x) for val in label_set: mask = mask | (x == val) return mask.float() transforms.append(Lambdad(keys=["label"], func=_remap)) else: transforms.append(Lambdad(keys=["label"], func=lambda x: (x > 0).float())) # Optional augmentation — supports bool or profile name if augment: from nobrainer.augmentation.profiles import get_augmentation_profile profile_name = augment if isinstance(augment, str) else "standard" aug_transforms = get_augmentation_profile(profile_name, keys=keys) transforms += aug_transforms if block_shape is not None: from monai.transforms import RandSpatialCropd transforms.append( RandSpatialCropd(keys=keys, roi_size=block_shape, random_size=False) ) # Use TrainableCompose so augmentation can be skipped during predict from 
nobrainer.augmentation.transforms import TrainableCompose compose = TrainableCompose(transforms) dataset = CacheDataset( data=data, transform=compose, cache_rate=cache_rate, num_workers=max(0, num_workers), ) return DataLoader( dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=torch.cuda.is_available(), **kwargs, ) # --------------------------------------------------------------------------- # Zarr v3 dataset (requires [zarr] extras) # --------------------------------------------------------------------------- class ZarrDataset(torch.utils.data.Dataset): """PyTorch Dataset backed by Zarr v3 stores. Each item in *data_list* is a dict with ``"image"`` (and optionally ``"label"``) keys pointing to ``.zarr`` store paths. """ def __init__( self, data_list: list[dict[str, str]], transform: Any | None = None, zarr_level: int = 0, ): self.data = data_list self.transform = transform self.level = zarr_level def __len__(self) -> int: return len(self.data) def __getitem__(self, idx: int) -> dict: import zarr item = self.data[idx] store = zarr.open_group(str(item["image"]), mode="r") img_arr = np.asarray(store[str(self.level)]).astype(np.float32) result: dict[str, Any] = {"image": img_arr[None]} # add channel dim if "label" in item: lbl_store = zarr.open_group(str(item["label"]), mode="r") lbl_arr = np.asarray(lbl_store[str(self.level)]).astype(np.float32) result["label"] = lbl_arr[None] if self.transform is not None: result = self.transform(result) # Convert to tensors if still numpy for k, v in result.items(): if isinstance(v, np.ndarray): result[k] = torch.from_numpy(v) return result def _is_zarr_path(path: str | Path) -> bool: """Check if a path looks like a Zarr store.""" return str(path).rstrip("/").endswith(".zarr") def _get_zarr_dataset( data: list[dict[str, str]], batch_size: int, num_workers: int, augment: bool, zarr_level: int, **kwargs: Any, ) -> DataLoader: """Build a DataLoader from Zarr v3 stores.""" transform = None if augment: 
import monai.transforms as mt transform = mt.Compose( [ mt.RandAffined( keys=list(data[0].keys()), prob=0.5, rotate_range=(0.1, 0.1, 0.1), ), mt.RandFlipd(keys=list(data[0].keys()), prob=0.5), ] ) dataset = ZarrDataset(data, transform=transform, zarr_level=zarr_level) return DataLoader( dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=torch.cuda.is_available(), **kwargs, ) __all__ = ["get_dataset", "ZarrDataset"] ================================================ FILE: nobrainer/datasets/__init__.py ================================================ """Dataset fetching utilities for various neuroimaging sources. Each submodule provides functions to install and fetch data from a specific source. All require the ``[versioning]`` optional extra (``datalad``, ``git-annex``) unless noted otherwise. Available sources ----------------- - :mod:`nobrainer.datasets.openneuro` — OpenNeuro raw + derivatives """ from __future__ import annotations def _check_datalad(): """Import datalad.api, raising a clear error if not available.""" try: import datalad.api as dl return dl except ImportError: raise ImportError( "DataLad is required for dataset fetching. " "Install with: pip install 'nobrainer[versioning]'\n" "Also install git-annex: uv tool install git-annex" ) from None __all__ = ["openneuro"] ================================================ FILE: nobrainer/datasets/openneuro.py ================================================ """Fetch datasets from OpenNeuro and OpenNeuro Derivatives via DataLad. Requires the ``[versioning]`` extra (``datalad >= 0.19``) and the ``git-annex`` PyPI package (``uv tool install git-annex`` or ``pip install git-annex``). 
Examples -------- Fetch fmriprep derivatives and get T1w + aparc+aseg pairs:: from nobrainer.datasets.openneuro import ( install_derivatives, find_subject_pairs, write_manifest, ) ds_path = install_derivatives("ds000114", "/tmp/data") pairs = find_subject_pairs(ds_path) write_manifest(pairs, "manifest.csv") Fetch a raw OpenNeuro dataset:: from nobrainer.datasets.openneuro import install_dataset ds_path = install_dataset("ds000114", "/tmp/data") Fetch specific files without auto-discovery:: from nobrainer.datasets.openneuro import ( install_derivatives, glob_dataset, fetch_files, ) ds_path = install_derivatives("ds000114", "/tmp/data") bold_files = glob_dataset(ds_path, "sub-*/func/*_bold.nii.gz") fetched = fetch_files(ds_path, bold_files[:5]) """ from __future__ import annotations import logging from pathlib import Path logger = logging.getLogger(__name__) _OPENNEURO_GH = "https://github.com/OpenNeuroDatasets" _OPENNEURO_DERIV_GH = "https://github.com/OpenNeuroDerivatives" def _dl(): """Lazy import of datalad.api.""" from nobrainer.datasets import _check_datalad return _check_datalad() # --------------------------------------------------------------------------- # Install (lightweight clone, no bulk download) # --------------------------------------------------------------------------- def install_dataset( dataset_id: str, path: str | Path, ) -> Path: """Clone an OpenNeuro dataset (metadata only, no file content). Parameters ---------- dataset_id : str OpenNeuro accession (e.g. ``"ds000114"``). path : str or Path Base directory. The dataset is cloned into ``/``. Returns ------- Path Absolute path to the installed dataset directory. 
""" dl = _dl() dest = Path(path) / dataset_id if dest.exists(): logger.info("Dataset %s already at %s", dataset_id, dest) return dest.resolve() source = f"{_OPENNEURO_GH}/{dataset_id}.git" logger.info("Installing %s from %s", dataset_id, source) dl.install(source=source, path=str(dest)) return dest.resolve() def install_derivatives( dataset_id: str, path: str | Path, derivative: str = "fmriprep", ) -> Path: """Clone an OpenNeuro Derivatives dataset (metadata only). Parameters ---------- dataset_id : str OpenNeuro accession (e.g. ``"ds000114"``). path : str or Path Base directory. Cloned into ``/-``. derivative : str Pipeline name (default ``"fmriprep"``). Common values: ``"fmriprep"``, ``"mriqc"``, ``"freesurfer"``. Returns ------- Path Absolute path to the installed derivative directory. """ dl = _dl() dest = Path(path) / f"{dataset_id}-{derivative}" if dest.exists(): logger.info("Derivative %s-%s already at %s", dataset_id, derivative, dest) return dest.resolve() source = f"{_OPENNEURO_DERIV_GH}/{dataset_id}-{derivative}.git" logger.info("Installing %s-%s from %s", dataset_id, derivative, source) dl.install(source=source, path=str(dest)) return dest.resolve() # --------------------------------------------------------------------------- # File discovery and download # --------------------------------------------------------------------------- def glob_dataset( dataset_dir: str | Path, pattern: str, ) -> list[Path]: """Glob a DataLad dataset directory (metadata only, no download). Works on the git tree — returned paths may be git-annex symlinks whose content hasn't been fetched yet. Parameters ---------- dataset_dir : str or Path Root of the DataLad dataset. pattern : str Glob pattern (e.g. ``"sub-*/anat/*_T1w.nii.gz"``). Returns ------- list of Path Sorted matching paths. """ return sorted(Path(dataset_dir).glob(pattern)) def fetch_files( dataset_dir: str | Path, paths: list[str | Path], ) -> list[Path]: """Download specific files from a DataLad dataset. 
Parameters ---------- dataset_dir : str or Path Root of the DataLad dataset. paths : list of str or Path Files to download (absolute or relative to *dataset_dir*). Returns ------- list of Path Paths whose content was successfully downloaded. """ dl = _dl() dataset_dir = Path(dataset_dir) try: dl.get([str(p) for p in paths], dataset=str(dataset_dir)) except Exception as exc: logger.warning("datalad get failed: %s", exc) return [p for p in (Path(x) for x in paths) if _file_ok(p)] # --------------------------------------------------------------------------- # Paired file discovery (structural MRI) # --------------------------------------------------------------------------- def _extract_subject_id(path: Path) -> str: """Extract ``sub-XX`` from a BIDS-style path. Checks directory components first (``sub-01/anat/...``), then parses the filename (``sub-01_desc-preproc_T1w.nii.gz``). """ # Check directory parts (e.g. .../sub-01/anat/...) for part in path.parts[:-1]: # skip filename if part.startswith("sub-"): return part # Parse from filename name = path.name if name.startswith("sub-"): return name.split("_")[0] return name def _file_ok(p: Path) -> bool: """True if *p* is a real file with nonzero size.""" try: return p.stat().st_size > 0 except OSError: return False def find_subject_pairs( dataset_dir: str | Path, feature_pattern: str | None = None, label_pattern: str | None = None, native_space: bool = True, download: bool = True, ) -> list[dict[str, str]]: """Discover and optionally download paired (feature, label) files. The default patterns find native-space preprocessed T1w images and aparc+aseg parcellations from fmriprep derivatives. Strategy: 1. Glob the dataset tree (git metadata only) to find label files. 2. For each label, find the matching feature file for the same subject. 3. Download each pair via ``datalad get``. 4. Verify both files are accessible before including them. 
Parameters ---------- dataset_dir : str or Path Root of a DataLad dataset (typically an fmriprep derivative). feature_pattern : str or None Glob for feature files. When *None*, discovers the best native-space T1w pattern automatically. label_pattern : str or None Glob for label files. When *None*, tries ``*desc-aparcaseg_dseg.nii.gz`` then ``*desc-aseg_dseg.nii.gz``. native_space : bool Prefer native-space files (no ``space-`` token). Default True. download : bool If True (default), download each pair via ``datalad get``. Returns ------- list of dict Each dict: ``{"subject_id", "t1w_path", "label_path"}``. """ dataset_dir = Path(dataset_dir) pairs: list[dict[str, str]] = [] # --- Discover label files --- if label_pattern is not None: label_files = glob_dataset(dataset_dir, label_pattern) else: label_files = [] for pat in [ "sub-*/anat/*desc-aparcaseg_dseg.nii.gz", "sub-*/anat/*desc-aseg_dseg.nii.gz", ]: label_files = glob_dataset(dataset_dir, pat) if label_files: logger.info("Found %d labels matching %s", len(label_files), pat) break if not label_files: logger.warning("No label files found in %s", dataset_dir) return pairs # --- Match each label to a feature file --- for label_path in label_files: sub_id = _extract_subject_id(label_path) anat_dir = label_path.parent if feature_pattern is not None: feat_candidates = sorted(anat_dir.glob(feature_pattern)) else: feat_candidates = [ p for p in anat_dir.glob(f"{sub_id}*desc-preproc_T1w.nii.gz") if (not native_space) or ("space-" not in p.name) ] if not feat_candidates: feat_candidates = sorted(anat_dir.glob(f"{sub_id}*_T1w.nii.gz"))[:1] if not feat_candidates: logger.warning("No feature file for %s", sub_id) continue feat_path = feat_candidates[0] if download: logger.info("Downloading pair for %s", sub_id) fetch_files(dataset_dir, [feat_path, label_path]) feat_ok = _file_ok(feat_path) if download else True label_ok = _file_ok(label_path) if download else True if feat_ok and label_ok: pairs.append( { "subject_id": 
sub_id, "t1w_path": str(feat_path), "label_path": str(label_path), } ) else: logger.warning("Skipping %s: files not accessible", sub_id) logger.info("Found %d paired subjects in %s", len(pairs), dataset_dir.name) return pairs # --------------------------------------------------------------------------- # Manifest writing # --------------------------------------------------------------------------- def write_manifest( pairs: list[dict[str, str]], output_path: str | Path, split_ratios: tuple[int, int, int] = (80, 10, 10), seed: int = 42, ) -> Path: """Write a manifest CSV with train/val/test split. Parameters ---------- pairs : list of dict Each dict has ``"subject_id"``, ``"t1w_path"``, ``"label_path"``. Optionally ``"dataset_id"``. output_path : str or Path Destination CSV. split_ratios : tuple of int (train, val, test) percentages. seed : int Random seed for reproducible splits. Returns ------- Path Written CSV path. """ import csv import numpy as np output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) rng = np.random.default_rng(seed) indices = rng.permutation(len(pairs)) total = sum(split_ratios) n_train = int(len(pairs) * split_ratios[0] / total) n_val = int(len(pairs) * split_ratios[1] / total) for i, idx in enumerate(indices): if i < n_train: pairs[idx]["split"] = "train" elif i < n_train + n_val: pairs[idx]["split"] = "val" else: pairs[idx]["split"] = "test" fieldnames = ["subject_id", "dataset_id", "t1w_path", "label_path", "split"] if not any("dataset_id" in p for p in pairs): fieldnames = [f for f in fieldnames if f != "dataset_id"] with open(output_path, "w", newline="") as f: writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") writer.writeheader() writer.writerows(pairs) counts = { s: sum(1 for p in pairs if p.get("split") == s) for s in ("train", "val", "test") } logger.info( "Manifest: %s — %d subjects (train=%d, val=%d, test=%d)", output_path, len(pairs), counts["train"], counts["val"], 
counts["test"], ) return output_path ================================================ FILE: nobrainer/datasets/zarr_store.py ================================================ """Multi-subject Zarr3 dataset store with sharding. Converts NIfTI collections into a single sharded Zarr3 store where subjects are stacked along a 4th dimension: ``images[N, D, H, W]`` and ``labels[N, D, H, W]``. This layout enables efficient partial I/O for training: reading one subject's patch is a single seek into one shard file. Requires the ``[zarr]`` optional extra (``zarr >= 3.0``). """ from __future__ import annotations import json import logging from pathlib import Path from typing import Any import numpy as np logger = logging.getLogger(__name__) def _conform_volume(img, target_shape, target_voxel_size=(1.0, 1.0, 1.0)): """Conform a nibabel image to target shape and voxel size.""" from nibabel.processing import conform return conform(img, out_shape=target_shape, voxel_size=target_voxel_size) def _infer_target_shape( image_paths: list[str | Path], max_scan: int = 50, ) -> tuple[tuple[int, int, int], tuple[float, float, float]]: """Infer target shape and voxel size from input volumes. Uses the median shape and modal voxel size across a sample of volumes. 
""" import nibabel as nib shapes = [] voxel_sizes = [] for p in image_paths[:max_scan]: img = nib.load(p) shapes.append(img.shape[:3]) voxel_sizes.append(tuple(np.abs(img.header.get_zooms()[:3]))) # Median shape (rounded to nearest integer) median_shape = tuple(int(np.median([s[i] for s in shapes])) for i in range(3)) # Modal voxel size (most common, or median if all different) from collections import Counter vox_counts = Counter(voxel_sizes) if vox_counts: modal_voxel = vox_counts.most_common(1)[0][0] else: modal_voxel = (1.0, 1.0, 1.0) return median_shape, modal_voxel def create_zarr_store( image_label_pairs: list[tuple[str, str]], output_path: str | Path, subject_ids: list[str] | None = None, chunk_shape: tuple[int, int, int] = (32, 32, 32), shard_shape: tuple[int, int, int] | None = None, compressor: str = "blosc", conform: bool = True, target_shape: tuple[int, int, int] | None = None, target_voxel_size: tuple[float, float, float] | None = None, ) -> Path: """Convert NIfTI pairs into a single sharded Zarr3 store. When ``conform=True`` (default), volumes are conformed to a uniform shape so they can be stacked into 4D arrays ``images[N, D, H, W]`` and ``labels[N, D, H, W]``. The target shape is inferred from the data (median shape) unless explicitly provided. Parameters ---------- image_label_pairs : list of (str, str) List of ``(image_path, label_path)`` tuples. output_path : str or Path Output Zarr store directory. subject_ids : list of str or None Subject identifiers. If None, auto-generated as ``sub-000``, etc. chunk_shape : tuple of int Spatial chunk dimensions (default 32³). shard_shape : tuple of int or None Shard dimensions. None = auto (full array or large multiple). compressor : str Compression codec name (default ``"blosc"``). conform : bool Auto-conform volumes to uniform shape (default True). target_shape : tuple of int or None Target spatial shape. None = infer from data. target_voxel_size : tuple of float or None Target voxel size. 
None = infer from data. Returns ------- Path Path to the created Zarr store. """ import nibabel as nib import zarr output_path = Path(output_path) n_subjects = len(image_label_pairs) if subject_ids is None: subject_ids = [f"sub-{i:03d}" for i in range(n_subjects)] if len(subject_ids) != n_subjects: raise ValueError( f"subject_ids length ({len(subject_ids)}) != pairs ({n_subjects})" ) image_paths = [p[0] for p in image_label_pairs] # Infer or validate target shape if conform: if target_shape is None or target_voxel_size is None: inferred_shape, inferred_voxel = _infer_target_shape(image_paths) if target_shape is None: target_shape = inferred_shape if target_voxel_size is None: target_voxel_size = inferred_voxel logger.info( "Inferred target: shape=%s, voxel_size=%s", target_shape, target_voxel_size, ) else: # Check all shapes are the same first_img = nib.load(image_paths[0]) target_shape = first_img.shape[:3] for p in image_paths[1:]: img = nib.load(p) if img.shape[:3] != target_shape: raise ValueError( f"Non-uniform shapes detected ({img.shape[:3]} vs {target_shape}). " "Use conform=True to auto-conform, or ensure all volumes match." ) D, H, W = target_shape full_chunk = (1, *chunk_shape) # one subject per chunk along axis 0 # Shard shape: group subjects into shards for balanced write parallelism # and read efficiency. Default: ~50 subjects per shard → manageable # file count while allowing parallel writes across shards. 
subjects_per_shard = 50 if shard_shape is not None: full_shard = shard_shape else: full_shard = (min(subjects_per_shard, n_subjects), D, H, W) # Create store store = zarr.open_group(str(output_path), mode="w") # Create sharded 4D arrays n_shards = int(np.ceil(n_subjects / full_shard[0])) images_arr = store.create_array( "images", shape=(n_subjects, D, H, W), chunks=full_chunk, shards=full_shard, dtype=np.float32, ) labels_arr = store.create_array( "labels", shape=(n_subjects, D, H, W), chunks=full_chunk, shards=full_shard, dtype=np.int32, ) logger.info( "Created sharded Zarr3: shape=%s, chunks=%s, shards=%s (%d shard files)", (n_subjects, D, H, W), full_chunk, full_shard, n_shards, ) # Write volumes — parallel across shards. # Each shard is independent, so we can write to different shards # concurrently. Within a shard, writes are sequential. import concurrent.futures import os n_workers = min(os.cpu_count() or 1, n_shards, 8) def _write_shard_group(shard_idx): """Load and write all subjects belonging to one shard.""" start = shard_idx * full_shard[0] end = min(start + full_shard[0], n_subjects) for i in range(start, end): img_path, lbl_path = image_label_pairs[i] img = nib.load(img_path) lbl = nib.load(lbl_path) if conform: img = _conform_volume(img, target_shape, target_voxel_size) lbl = _conform_volume(lbl, target_shape, target_voxel_size) images_arr[i] = np.asarray(img.dataobj, dtype=np.float32)[:D, :H, :W] labels_arr[i] = np.asarray(lbl.dataobj, dtype=np.int32)[:D, :H, :W] return end - start logger.info( "Writing %d volumes across %d shards with %d workers...", n_subjects, n_shards, n_workers, ) with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool: futures = [pool.submit(_write_shard_group, s) for s in range(n_shards)] done = 0 for future in concurrent.futures.as_completed(futures): done += future.result() logger.info("Stored %d/%d volumes", done, n_subjects) # Store metadata store.attrs["n_subjects"] = n_subjects 
store.attrs["subject_ids"] = subject_ids store.attrs["volume_shape"] = list(target_shape) store.attrs["chunk_shape"] = list(chunk_shape) store.attrs["layout"] = "stacked" store.attrs["image_dtype"] = "float32" store.attrs["label_dtype"] = "int32" if conform: store.attrs["conformed"] = True store.attrs["target_shape"] = [int(x) for x in target_shape] store.attrs["target_voxel_size"] = [float(x) for x in target_voxel_size] else: store.attrs["conformed"] = False logger.info( "Zarr store created: %s (%d subjects, shape=%s)", output_path, n_subjects, target_shape, ) return output_path.resolve() def store_info(store_path: str | Path) -> dict[str, Any]: """Return store metadata without reading voxel data. Parameters ---------- store_path : str or Path Path to a Zarr store. Returns ------- dict Store metadata including n_subjects, volume_shape, subject_ids, etc. """ import zarr store = zarr.open_group(str(store_path), mode="r") return dict(store.attrs) def create_partition( store_path: str | Path, ratios: tuple[int, int, int] = (80, 10, 10), seed: int = 42, output_path: str | Path | None = None, ) -> Path: """Generate a partition index JSON file. Parameters ---------- store_path : str or Path Path to the Zarr store. ratios : tuple of int (train, val, test) percentages. seed : int Random seed for reproducibility. output_path : str or Path or None Output JSON path. None = ``_partition.json``. Returns ------- Path Path to the written partition JSON file. 
""" info = store_info(store_path) subject_ids = info["subject_ids"] n = len(subject_ids) rng = np.random.default_rng(seed) indices = rng.permutation(n) total = sum(ratios) n_train = int(n * ratios[0] / total) n_val = int(n * ratios[1] / total) train_ids = [subject_ids[i] for i in indices[:n_train]] val_ids = [subject_ids[i] for i in indices[n_train : n_train + n_val]] test_ids = [subject_ids[i] for i in indices[n_train + n_val :]] partition = { "seed": seed, "ratios": list(ratios), "n_subjects": n, "store_path": str(store_path), "partitions": { "train": train_ids, "val": val_ids, "test": test_ids, }, } if output_path is None: output_path = Path(str(store_path) + "_partition.json") output_path = Path(output_path) with open(output_path, "w") as f: json.dump(partition, f, indent=2) logger.info( "Partition created: %s (train=%d, val=%d, test=%d)", output_path, len(train_ids), len(val_ids), len(test_ids), ) return output_path def load_partition(partition_path: str | Path) -> dict[str, list[str]]: """Load a partition index and return ``{split: [subject_ids]}``. Parameters ---------- partition_path : str or Path Path to a partition JSON file. Returns ------- dict ``{"train": [...], "val": [...], "test": [...]}``. """ with open(partition_path) as f: data = json.load(f) return data["partitions"] ================================================ FILE: nobrainer/distributed_learning/dwc.py ================================================ import numpy as np # Distributed weight consolidation for Bayesian Deep Neural Networks # Implemented according to the: # McClure, Patrick, et al. Distributed weight consolidation: a brain segmentation case study. # Advances in neural information processing systems 31 (2018): 4093. def distributed_weight_consolidation(model_weights, model_priors): # model_weights is a list of weights of client-models; models = [model1, model2, model3...] 
# model_priors is a list of priors of client models sames as models num_layers = int(len(model_weights[0]) / 2.0) num_datasets = np.shape(model_weights)[0] consolidated_model = model_weights[0] mean_idx = [i for i in range(0, len(model_weights[0])) if i % 2 == 0] std_idx = [i for i in range(0, len(model_weights[0])) if i % 2 != 0] ep = 1e-5 for i in range(num_layers): num_1 = 0 num_2 = 0 den_1 = 0 den_2 = 0 for m in range(num_datasets): model = model_weights[m] prior = model_priors[m] mu_s = model[mean_idx[i]] mu_o = prior[mean_idx[i]] sig_s = model[std_idx[i]] sig_o = prior[std_idx[i]] d1 = np.power(sig_s, 2) + ep d2 = np.power(sig_o, 2) + ep num_1 = num_1 + (mu_s / d1) num_2 = num_2 + (mu_o / d2) den_1 = den_1 + (1.0 / d1) den_2 = den_2 + (1.0 / d2) consolidated_model[mean_idx[i]] = (num_1 - num_2) / (den_1 - den_2) consolidated_model[std_idx[i]] = 1 / (den_1 - den_2) return consolidated_model ================================================ FILE: nobrainer/experiment.py ================================================ """Experiment tracking: local file logger + optional Weights & Biases. Provides a unified interface for logging training metrics. The local logger always works (writes JSON lines + CSV to the output directory). W&B integration is optional and auto-detected. 
Usage:: from nobrainer.experiment import ExperimentTracker # Local-only (writes to output_dir/metrics.jsonl + metrics.csv) tracker = ExperimentTracker(output_dir="checkpoints/bvwn", config={...}) # With W&B (if wandb is installed and WANDB_API_KEY is set) tracker = ExperimentTracker( output_dir="checkpoints/bvwn", config={"lr": 1e-4, "filters": 96}, project="kwyk-reproduction", tags=["bvwn_multi_prior", "50-class"], ) for epoch in range(epochs): tracker.log({"epoch": epoch, "train_loss": loss, "val_dice": dice}) tracker.finish() """ from __future__ import annotations import csv import json import logging import os from pathlib import Path from typing import Any logger = logging.getLogger(__name__) class _LocalLogger: """Write metrics to JSON lines + CSV in the output directory.""" def __init__(self, output_dir: Path) -> None: self.output_dir = Path(output_dir) self.output_dir.mkdir(parents=True, exist_ok=True) self.jsonl_path = self.output_dir / "metrics.jsonl" self.csv_path = self.output_dir / "metrics.csv" self._csv_writer = None self._csv_file = None self._fieldnames: list[str] | None = None def log(self, metrics: dict[str, Any]) -> None: # JSON lines (append) with open(self.jsonl_path, "a") as f: f.write(json.dumps(metrics, default=str) + "\n") # CSV (create header on first call, append rows) if self._csv_writer is None: self._fieldnames = list(metrics.keys()) self._csv_file = open(self.csv_path, "w", newline="") self._csv_writer = csv.DictWriter( self._csv_file, fieldnames=self._fieldnames, extrasaction="ignore" ) self._csv_writer.writeheader() self._csv_writer.writerow(metrics) self._csv_file.flush() def log_config(self, config: dict[str, Any]) -> None: with open(self.output_dir / "config.json", "w") as f: json.dump(config, f, indent=2, default=str) def finish(self) -> None: if self._csv_file is not None: self._csv_file.close() self._csv_file = None self._csv_writer = None class _WandbLogger: """Log metrics to Weights & Biases.""" def __init__( self, config: 
dict[str, Any], project: str | None, name: str | None, tags: list[str] | None, ) -> None: import wandb self._wandb = wandb self._run = wandb.init( project=project or "nobrainer", name=name, config=config, tags=tags, reinit=True, ) def log(self, metrics: dict[str, Any]) -> None: self._wandb.log(metrics) def log_config(self, config: dict[str, Any]) -> None: self._run.config.update(config, allow_val_change=True) def finish(self) -> None: self._wandb.finish() class ExperimentTracker: """Unified experiment tracker with local + optional W&B backends. The local backend always runs, writing ``metrics.jsonl``, ``metrics.csv``, and ``config.json`` to *output_dir*. W&B is activated when: 1. ``wandb`` is installed, AND 2. ``WANDB_API_KEY`` is set or ``use_wandb=True`` is passed. Parameters ---------- output_dir : str or Path Directory for local metric files. config : dict, optional Hyperparameters / configuration to log. project : str, optional W&B project name (default ``"nobrainer"``). name : str, optional W&B run name. tags : list of str, optional W&B run tags. use_wandb : bool or None Force W&B on/off. None = auto-detect (use if installed + key set). 
""" def __init__( self, output_dir: str | Path, config: dict[str, Any] | None = None, project: str | None = None, name: str | None = None, tags: list[str] | None = None, use_wandb: bool | None = None, ) -> None: self._backends: list[Any] = [] # Local logger (always active) local = _LocalLogger(Path(output_dir)) self._backends.append(local) # Save config locally if config: local.log_config(config) # W&B (optional) if use_wandb is None: use_wandb = ( os.environ.get("WANDB_API_KEY") is not None or os.environ.get("WANDB_MODE") == "offline" ) if use_wandb: try: wb = _WandbLogger( config=config or {}, project=project, name=name, tags=tags, ) self._backends.append(wb) logger.info("W&B tracking enabled (project=%s)", project) except Exception as exc: logger.warning("W&B init failed: %s — using local only", exc) backend_names = [type(b).__name__ for b in self._backends] logger.info("Experiment tracking: %s", ", ".join(backend_names)) def log(self, metrics: dict[str, Any]) -> None: """Log a dict of metrics to all backends.""" for backend in self._backends: backend.log(metrics) def log_config(self, config: dict[str, Any]) -> None: """Log/update configuration to all backends.""" for backend in self._backends: backend.log_config(config) def finish(self) -> None: """Finalize all backends (flush files, end W&B run).""" for backend in self._backends: backend.finish() def callback(self, **extra_fields) -> callable: """Return a training callback that logs epoch metrics. The returned callable has signature ``(epoch, logs, model)`` — matching the callback protocol in :func:`nobrainer.training.fit` and :class:`Segmentation.fit`. Parameters ---------- **extra_fields Extra key-value pairs included in every log entry (e.g., ``variant="bvwn_multi_prior"``). 
Example:: tracker = ExperimentTracker("checkpoints/bvwn", config={...}) seg.fit(ds, epochs=50, callbacks=[tracker.callback(variant="ssd")]) tracker.finish() """ def _cb(epoch: int, logs: dict, model: Any) -> None: self.log({"epoch": epoch, **logs, **extra_fields}) return _cb ================================================ FILE: nobrainer/gpu.py ================================================ """GPU utilities: device detection, memory profiling, batch size optimization. Examples -------- Auto-select the best batch size for a model and block shape:: from nobrainer.gpu import auto_batch_size, gpu_info info = gpu_info() print(info) # [{'name': 'Tesla T4', 'memory_gb': 15.1, 'id': 0}, ...] batch_size = auto_batch_size( model=my_model, block_shape=(32, 32, 32), n_classes=2, target_memory_fraction=0.85, ) print(f"Optimal batch size: {batch_size}") Scale batch size for multi-GPU:: from nobrainer.gpu import scale_for_multi_gpu effective_batch, per_gpu_batch, n_gpus = scale_for_multi_gpu( base_batch_size=32, block_shape=(32, 32, 32), ) # On 4x T4: effective=128, per_gpu=32, n_gpus=4 """ from __future__ import annotations import logging from typing import Any import torch import torch.nn as nn logger = logging.getLogger(__name__) def get_device() -> torch.device: """Select the best available device: CUDA > MPS > CPU.""" if torch.cuda.is_available(): return torch.device("cuda") if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") def gpu_count() -> int: """Return the number of CUDA GPUs available (0 if none).""" if torch.cuda.is_available(): return torch.cuda.device_count() return 0 def gpu_info() -> list[dict[str, Any]]: """Return a list of dicts with GPU name, memory, and id. Returns an empty list if no CUDA GPUs are available. 
""" info = [] if not torch.cuda.is_available(): return info for i in range(torch.cuda.device_count()): props = torch.cuda.get_device_properties(i) info.append( { "id": i, "name": props.name, "memory_gb": round(props.total_memory / 1e9, 1), "compute_capability": f"{props.major}.{props.minor}", } ) return info def _estimate_memory_per_sample( model: nn.Module, block_shape: tuple[int, int, int], n_classes: int = 2, in_channels: int = 1, dtype: torch.dtype = torch.float32, forward_kwargs: dict | None = None, ) -> float: """Estimate GPU memory (bytes) for one training sample. Runs a forward + backward pass with batch_size=1 and measures the peak allocated memory. The model is moved to GPU temporarily. Parameters ---------- model : nn.Module Model to profile. block_shape : tuple of int Spatial dimensions of one input patch. n_classes : int Number of output classes. in_channels : int Number of input channels. dtype : torch.dtype Input data type. Returns ------- float Estimated bytes per sample (forward + backward + optimizer overhead). """ if forward_kwargs is None: forward_kwargs = {} if not torch.cuda.is_available(): raise RuntimeError("CUDA required for memory estimation") device = torch.device("cuda") model = model.to(device) model.train() torch.cuda.reset_peak_memory_stats(device) torch.cuda.empty_cache() baseline = torch.cuda.memory_allocated(device) x = torch.randn(1, in_channels, *block_shape, device=device, dtype=dtype) labels = torch.randint(0, n_classes, (1, *block_shape), device=device) # Pass forward_kwargs if model accepts them (e.g. 
mc_vwn, mc_dropout) try: out = model(x, **forward_kwargs) except TypeError: out = model(x) loss = nn.CrossEntropyLoss()(out, labels) loss.backward() peak = torch.cuda.max_memory_allocated(device) - baseline # Clean up model.zero_grad(set_to_none=True) del x, labels, out, loss torch.cuda.empty_cache() model.cpu() return float(peak) def auto_batch_size( model: nn.Module, block_shape: tuple[int, int, int], n_classes: int = 2, in_channels: int = 1, target_memory_fraction: float = 0.85, gpu_id: int = 0, min_batch: int = 1, max_batch: int = 512, forward_kwargs: dict | None = None, ) -> int: """Estimate the largest batch size that fits in GPU memory. Profiles one sample, then scales to fill ``target_memory_fraction`` of the GPU. Parameters ---------- model : nn.Module Model to profile (will be temporarily moved to GPU). block_shape : tuple of int Spatial dimensions ``(D, H, W)`` of one input patch. n_classes : int Number of output classes. in_channels : int Number of input channels. target_memory_fraction : float Fraction of total GPU memory to target (default 0.85). gpu_id : int Which GPU to profile. min_batch : int Minimum batch size to return. max_batch : int Maximum batch size to return. Returns ------- int Recommended batch size for one GPU. 
""" if not torch.cuda.is_available(): logger.warning("No CUDA — returning min_batch=%d", min_batch) return min_batch total_mem = torch.cuda.get_device_properties(gpu_id).total_memory target_mem = total_mem * target_memory_fraction try: mem_per_sample = _estimate_memory_per_sample( model, block_shape, n_classes, in_channels, forward_kwargs=forward_kwargs, ) except RuntimeError as e: logger.warning("Memory estimation failed: %s — returning min_batch", e) return min_batch # Account for ~20% overhead (optimizer state, fragmentation) effective_per_sample = mem_per_sample * 1.2 batch = int(target_mem / effective_per_sample) batch = max(min_batch, min(batch, max_batch)) logger.info( "auto_batch_size: %.1f GB total, %.1f MB/sample, " "target %.0f%% → batch_size=%d", total_mem / 1e9, mem_per_sample / 1e6, target_memory_fraction * 100, batch, ) return batch def scale_for_multi_gpu( base_batch_size: int, block_shape: tuple[int, int, int] | None = None, model: nn.Module | None = None, n_classes: int = 2, target_memory_fraction: float = 0.85, ) -> tuple[int, int, int]: """Scale batch size for multi-GPU training. If ``model`` is provided, uses :func:`auto_batch_size` to determine the per-GPU batch size based on actual memory profiling. Otherwise, divides ``base_batch_size`` evenly across available GPUs. Parameters ---------- base_batch_size : int Desired effective (global) batch size. block_shape : tuple of int, optional Spatial dimensions for memory profiling. model : nn.Module, optional Model for memory profiling. If None, uses simple division. n_classes : int Number of output classes (for profiling). target_memory_fraction : float Target GPU memory fraction (for profiling). Returns ------- effective_batch : int Total batch size across all GPUs. per_gpu_batch : int Batch size per GPU. n_gpus : int Number of GPUs to use. 
""" n_gpus = gpu_count() if n_gpus == 0: return base_batch_size, base_batch_size, 0 if model is not None and block_shape is not None: per_gpu = auto_batch_size( model, block_shape, n_classes=n_classes, target_memory_fraction=target_memory_fraction, ) else: per_gpu = max(1, base_batch_size // n_gpus) effective = per_gpu * n_gpus logger.info( "Multi-GPU scaling: %d GPUs × %d per-GPU = %d effective batch", n_gpus, per_gpu, effective, ) return effective, per_gpu, n_gpus ================================================ FILE: nobrainer/io.py ================================================ """Input/output utilities for nobrainer (PyTorch, no TensorFlow).""" from __future__ import annotations import csv import hashlib from pathlib import Path import struct from typing import Any import h5py import nibabel as nib import numpy as np import torch import torch.nn as nn # --------------------------------------------------------------------------- # CSV helpers (no TF dependency) # --------------------------------------------------------------------------- def read_csv( filepath: str | Path, skip_header: bool = True, delimiter: str = "," ) -> list: """Return list of tuples from a CSV file.""" with open(filepath, newline="") as f: reader = csv.reader(f, delimiter=delimiter) if skip_header: next(reader) return [tuple(row) for row in reader] def read_mapping( filepath: str | Path, skip_header: bool = True, delimiter: str = "," ) -> dict[str, str]: """Read CSV as dict; first column → keys, second → values.""" rows = read_csv(filepath, skip_header=skip_header, delimiter=delimiter) return {row[0]: row[1] for row in rows} # --------------------------------------------------------------------------- # TFRecord conversion (T022) # --------------------------------------------------------------------------- def _compute_sha256(path: str | Path) -> str: h = hashlib.sha256() with open(path, "rb") as f: for chunk in iter(lambda: f.read(65536), b""): h.update(chunk) return h.hexdigest() def 
_parse_tfrecord_file(path: str | Path): """Yield raw TFRecord byte strings from a .tfrecord file. TFRecord format: [length:uint64][masked_crc32:uint32][data][masked_crc32:uint32] """ with open(path, "rb") as f: while True: header = f.read(12) if not header: break (length,) = struct.unpack_from(" list[str]: """Convert TFRecord files to NIfTI or HDF5. Uses the ``tfrecord`` PyPI package — no TensorFlow required. Parameters ---------- tfrecord_paths : list Paths to ``.tfrecord`` files. output_dir : str or Path Directory where converted files are written. volume_shape : tuple or None Expected shape ``(D, H, W, C)`` of the stored arrays. Used to validate/reshape the parsed tensors. output_format : str ``"nifti"`` (writes ``.nii.gz``) or ``"hdf5"`` (writes ``.h5``). affine : ndarray or None 4×4 affine matrix for NIfTI files. Defaults to identity. verify_checksum : bool Compute SHA-256 of each output file after writing. Returns ------- list of str Paths to converted output files. """ import tfrecord # noqa: F401 (optional dep) from tfrecord.reader import tfrecord_loader output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) if affine is None: affine = np.eye(4) out_paths: list[str] = [] for rec_path in tfrecord_paths: rec_path = Path(rec_path) loader = tfrecord_loader( str(rec_path), index_path=None, description={"volume": "byte", "label": "byte"}, ) for i, record in enumerate(loader): volume_bytes = record.get("volume") or record.get("image") label_bytes = record.get("label") vol_arr = np.frombuffer(volume_bytes, dtype=np.float32) if volume_shape is not None: vol_arr = vol_arr.reshape(volume_shape) # TF stores (D,H,W,C), PyTorch wants (C,D,H,W) if vol_arr.ndim == 4: vol_arr = np.transpose(vol_arr, (3, 0, 1, 2)) stem = rec_path.stem if output_format == "hdf5": out_path = output_dir / f"{stem}_{i:04d}.h5" with h5py.File(out_path, "w") as hf: hf.create_dataset("volume", data=vol_arr, compression="gzip") if label_bytes is not None: lbl_arr = 
np.frombuffer(label_bytes, dtype=np.float32) if volume_shape is not None: lbl_arr = lbl_arr.reshape(volume_shape) if lbl_arr.ndim == 4: lbl_arr = np.transpose(lbl_arr, (3, 0, 1, 2)) hf.create_dataset("label", data=lbl_arr, compression="gzip") else: # NIfTI: use first channel as spatial volume spatial = vol_arr[0] if vol_arr.ndim == 4 else vol_arr img = nib.Nifti1Image(spatial.astype(np.float32), affine) out_path = output_dir / f"{stem}_{i:04d}.nii.gz" nib.save(img, str(out_path)) out_paths.append(str(out_path)) if verify_checksum: _compute_sha256(out_path) # validates file integrity return out_paths # --------------------------------------------------------------------------- # Weight conversion: TF Keras H5 → PyTorch (T024) # --------------------------------------------------------------------------- # Mapping patterns: (keras_layer_keyword, param_suffix) → pytorch_param_name _CONV_MAPPING = { "kernel": "weight", "bias": "bias", "gamma": "weight", # BatchNorm scale "beta": "bias", # BatchNorm shift "moving_mean": "running_mean", "moving_variance": "running_var", } def _keras_conv3d_to_pytorch(w: np.ndarray) -> np.ndarray: """Transpose Conv3D weights from Keras (D,H,W,Cin,Cout) → PyTorch (Cout,Cin,D,H,W).""" if w.ndim == 5: return np.transpose(w, (4, 3, 0, 1, 2)) return w def convert_weights( h5_path: str | Path, pt_model: nn.Module, layer_mapping: dict[str, str] | None = None, output_path: str | Path | None = None, verify: bool = False, ) -> dict[str, torch.Tensor]: """Load Keras ``.h5`` weights and map them to a PyTorch model. No TensorFlow is required; weights are read directly with ``h5py``. Parameters ---------- h5_path : str or Path Path to the Keras ``.h5`` weight file. pt_model : nn.Module Target PyTorch model whose ``state_dict`` will receive the weights. layer_mapping : dict or None ``{keras_layer_name: pytorch_submodule_name}`` mapping. When ``None``, an automatic heuristic attempts to match by index. 
output_path : str or Path or None If provided, save the converted state dict to ``.pth``. verify : bool Run a brief forward-pass verification after loading (raises if shapes mismatch). Returns ------- dict The loaded (possibly partial) state dict. """ h5_path = Path(h5_path) state = pt_model.state_dict() new_state: dict[str, torch.Tensor] = {} with h5py.File(h5_path, "r") as hf: # Traverse all datasets in the H5 file def _collect(name: str, obj: Any) -> None: if not isinstance(obj, h5py.Dataset): return w = obj[()] # numpy array # Apply weight transposition for Conv3D kernels if "kernel" in name and w.ndim == 5: w = _keras_conv3d_to_pytorch(w) # Determine target PyTorch parameter name pt_name = _map_name(name, layer_mapping, state) if pt_name is not None and pt_name in state: tensor = torch.from_numpy(w.copy()) if tensor.shape == state[pt_name].shape: new_state[pt_name] = tensor hf.visititems(_collect) # Load matched weights; keep existing for unmatched combined = {**state, **new_state} pt_model.load_state_dict(combined, strict=False) if output_path is not None: torch.save(combined, str(output_path)) if verify: pt_model.eval() dummy = torch.zeros(1, 1, 32, 32, 32) with torch.no_grad(): _ = pt_model(dummy) return new_state def _map_name( h5_name: str, mapping: dict[str, str] | None, state: dict[str, torch.Tensor], ) -> str | None: """Attempt to resolve an H5 dataset path to a PyTorch state-dict key.""" # Simple heuristic: look for a state-dict key that contains the leaf name parts = h5_name.replace("/", ".").split(".") leaf = parts[-1] pt_leaf = _CONV_MAPPING.get(leaf, leaf) if mapping: for k, v in mapping.items(): if k in h5_name: candidate = f"{v}.{pt_leaf}" if candidate in state: return candidate # Fallback: direct match candidate = ".".join(parts[:-1] + [pt_leaf]) return candidate if candidate in state else None # --------------------------------------------------------------------------- # Zarr v3 conversion (requires [zarr] extras) # 
--------------------------------------------------------------------------- def nifti_to_zarr( input_path: str | Path, output_path: str | Path, chunk_shape: tuple[int, int, int] = (64, 64, 64), shard_shape: tuple[int, int, int] | None = None, compressor: str = "blosc", levels: int = 1, ) -> Path: """Convert a NIfTI file to a sharded Zarr v3 store. Uses ``niizarr.nii2zarr`` for NIfTI-Zarr specification compliance (NIfTI header, OME multiscale metadata) and adds nobrainer provenance. Parameters ---------- input_path : path to .nii or .nii.gz file output_path : path for the output .zarr directory chunk_shape : inner chunk dimensions shard_shape : outer shard dimensions; None lets niizarr choose compressor : compression codec ("blosc" or "zlib") levels : number of resolution levels (1 = single level, -1 = auto) Returns ------- Path to the created .zarr store """ import datetime import zarr import nobrainer img = nib.load(str(input_path)) output_path = Path(output_path) try: import niizarr niizarr.nii2zarr( str(input_path), str(output_path), chunk=chunk_shape, shard=shard_shape, nb_levels=levels, compressor=compressor, zarr_version=3, ) except ImportError: # Fallback: manual Zarr v3 creation without niizarr arr = np.asarray(img.dataobj, dtype=np.float32) clamped_chunk = tuple(min(c, s) for c, s in zip(chunk_shape, arr.shape)) if shard_shape is None: eff_shard = tuple(min(c * 2, s) for c, s in zip(clamped_chunk, arr.shape)) else: eff_shard = tuple(max(c, s) for c, s in zip(clamped_chunk, shard_shape)) store = zarr.open_group(str(output_path), mode="w", zarr_format=3) store.create_array("0", data=arr, chunks=clamped_chunk, shards=eff_shard) if levels > 1: from scipy.ndimage import zoom for lvl in range(1, levels): factor = 1 / 2**lvl down = zoom(arr, factor, order=1).astype(arr.dtype) lc = tuple(min(c, s) for c, s in zip(clamped_chunk, down.shape)) ls = tuple(min(c * 2, s) for c, s in zip(lc, down.shape)) store.create_array(str(lvl), data=down, chunks=lc, shards=ls) 
store.attrs["nifti_affine"] = img.affine.tolist() # Store provenance in group attrs store = zarr.open_group(str(output_path), mode="r+") store.attrs["nobrainer_provenance"] = { "source_file": str(Path(input_path).name), "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(), "tool": "nobrainer.io.nifti_to_zarr", "nobrainer_version": nobrainer.__version__, "chunk_shape": list(chunk_shape), "levels": levels, } return output_path def zarr_to_nifti( input_path: str | Path, output_path: str | Path, level: int = 0, ) -> Path: """Convert a Zarr v3 store back to NIfTI. Tries ``niizarr.zarr2nii`` first for NIfTI-Zarr spec compliance. Falls back to reading the array + stored affine if niizarr is unavailable or fails. Parameters ---------- input_path : path to .zarr directory output_path : path for the output .nii.gz file level : resolution level to export (0 = full resolution) Returns ------- Path to the created NIfTI file """ import zarr output_path = Path(output_path) # Try niizarr first try: import niizarr img = niizarr.zarr2nii(str(input_path), level=level) if img.affine is not None: nib.save(img, str(output_path)) return output_path except Exception: pass # Fallback: manual read store = zarr.open_group(str(input_path), mode="r") arr = np.asarray(store[str(level)]) affine = np.array(store.attrs.get("nifti_affine", np.eye(4).tolist())) img = nib.Nifti1Image(arr.astype(np.float32), affine) nib.save(img, str(output_path)) return output_path __all__ = [ "read_csv", "read_mapping", "convert_tfrecords", "convert_weights", "nifti_to_zarr", "zarr_to_nifti", ] ================================================ FILE: nobrainer/layers/InstanceNorm.py ================================================ import logging from ..layers.groupnorm import GroupNormalization class InstanceNormalization(GroupNormalization): """Instance normalization layer. Instance Normalization is an specific case of ```GroupNormalization```since it normalizes all features of one channel. 
The Groupsize is equal to the channel size. Empirically, its accuracy is more stable than batch norm in a wide range of small batch sizes, if learning rate is adjusted linearly with batch sizes. Arguments axis: Integer, the axis that should be normalized. epsilon: Small float added to variance to avoid dividing by zero. center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. beta_initializer: Initializer for the beta weight. gamma_initializer: Initializer for the gamma weight. beta_regularizer: Optional regularizer for the beta weight. gamma_regularizer: Optional regularizer for the gamma weight. beta_constraint: Optional constraint for the beta weight. gamma_constraint: Optional constraint for the gamma weight. Input shape Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model. Output shape Same shape as input. 
References - [Instance Normalization: The Missing Ingredient for Fast Stylization] (https://arxiv.org/abs/1607.08022) """ def __init__(self, **kwargs): if "groups" in kwargs: logging.warning("The given value for groups will be overwritten.") kwargs["groups"] = -1 super().__init__(**kwargs) ================================================ FILE: nobrainer/layers/__init__.py ================================================ from .bernoulli_dropout import BernoulliDropout from .concrete_dropout import ConcreteDropout from .gaussian_dropout import GaussianDropout from .maxpool4d import MaxPool4D from .padding import ZeroPadding3DChannels __all__ = [ "BernoulliDropout", "ConcreteDropout", "GaussianDropout", "MaxPool4D", "ZeroPadding3DChannels", ] ================================================ FILE: nobrainer/layers/bernoulli_dropout.py ================================================ """Bernoulli dropout layer for PyTorch.""" import torch import torch.nn as nn class BernoulliDropout(nn.Module): """Bernoulli dropout layer. Multiplies input by a Bernoulli mask sampled with keep probability ``1 - rate``. When ``scale_during_training`` is ``True`` the output is rescaled by ``1 / keep_prob`` so that the expected value is preserved (inverted dropout). When it is ``False`` the raw Bernoulli mask is applied and the output is scaled by ``keep_prob`` at test time. Parameters ---------- rate : float Drop probability (0 ≤ rate < 1). is_monte_carlo : bool When ``True`` the stochastic mask is applied regardless of ``training`` mode (enables MC-Dropout inference). scale_during_training : bool When ``True`` uses inverted dropout (scale at train time). When ``False`` scales at test time instead. seed : int or None Optional RNG seed (used to create a per-layer Generator). References ---------- Dropout: A Simple Way to Prevent Neural Networks from Overfitting. N. Srivastava et al., JMLR 2014. 
""" def __init__( self, rate: float, is_monte_carlo: bool, scale_during_training: bool = True, seed: int | None = None, ) -> None: super().__init__() if not 0.0 <= rate < 1.0: raise ValueError(f"rate must be in [0, 1), got {rate}") self.rate = rate self.is_monte_carlo = is_monte_carlo self.scale_during_training = scale_during_training self.keep_prob = 1.0 - rate self._generator: torch.Generator | None = None if seed is not None: self._generator = torch.Generator() self._generator.manual_seed(seed) def forward(self, x: torch.Tensor) -> torch.Tensor: apply_mask = self.is_monte_carlo or self.training if apply_mask: mask = torch.bernoulli( torch.full_like(x, self.keep_prob), generator=self._generator ) out = x * mask return out / self.keep_prob if self.scale_during_training else out # deterministic path return x if self.scale_during_training else self.keep_prob * x def extra_repr(self) -> str: return ( f"rate={self.rate}, is_monte_carlo={self.is_monte_carlo}, " f"scale_during_training={self.scale_during_training}" ) ================================================ FILE: nobrainer/layers/concrete_dropout.py ================================================ """Concrete Dropout layer for PyTorch.""" import math import torch import torch.nn as nn class ConcreteDropout(nn.Module): """Concrete (relaxed Bernoulli) dropout layer. Learns a per-channel drop probability ``p_post`` end-to-end via a differentiable relaxation of the Bernoulli mask. A KL-divergence regulariser between ``p_post`` and a fixed prior ``p_prior = 0.5`` is accumulated in ``self.kl_loss`` after each forward call. Parameters ---------- in_channels : int Number of input channels (last dimension of the input tensor). is_monte_carlo : bool When ``True`` the stochastic concrete mask is applied regardless of ``training`` mode. temperature : float Temperature of the concrete distribution (lower → more binary). use_expectation : bool At test time, use ``x * p_post`` instead of the identity. 
scale_factor : float Normalisation factor for the KL regulariser. seed : int or None Optional RNG seed. References ---------- Concrete Dropout. Y. Gal, J. Hron & A. Kendall, NeurIPS 2017. """ def __init__( self, in_channels: int, is_monte_carlo: bool = False, temperature: float = 0.02, use_expectation: bool = False, scale_factor: float = 1.0, seed: int | None = None, ) -> None: super().__init__() self.is_monte_carlo = is_monte_carlo self.temperature = temperature self.use_expectation = use_expectation self.scale_factor = scale_factor self._generator: torch.Generator | None = None if seed is not None: self._generator = torch.Generator() self._generator.manual_seed(seed) # Learnable drop probability (per channel), initialised near 0.9 self.p_logit = nn.Parameter(torch.full((in_channels,), math.log(0.9 / 0.1))) # Fixed prior p = 0.5 → logit = 0 self.register_buffer("p_prior", torch.full((in_channels,), 0.5)) self.kl_loss: torch.Tensor = torch.tensor(0.0) @property def p_post(self) -> torch.Tensor: """Dropout probability clipped to (0.05, 0.95).""" return torch.sigmoid(self.p_logit).clamp(0.05, 0.95) def forward(self, x: torch.Tensor) -> torch.Tensor: apply_mask = self.is_monte_carlo or self.training if apply_mask: out = self._apply_concrete(x) else: out = x * self.p_post if self.use_expectation else x self.kl_loss = self._kl_divergence() return out def _apply_concrete(self, x: torch.Tensor) -> torch.Tensor: eps = torch.finfo(x.dtype).eps p = self.p_post # (C,) noise = torch.rand( x.shape, dtype=x.dtype, device=x.device, generator=self._generator ).clamp(eps, 1.0 - eps) z = torch.sigmoid( ( torch.log(p + eps) - torch.log(1.0 - p + eps) + torch.log(noise) - torch.log(1.0 - noise) ) / self.temperature ) return x * z def _kl_divergence(self) -> torch.Tensor: eps = 1e-7 p = self.p_post pr = self.p_prior kl = p * (torch.log(p + eps) - torch.log(pr + eps)) + (1 - p) * ( torch.log(1 - p + eps) - torch.log(1 - pr + eps) ) return kl.sum() / self.scale_factor def 
extra_repr(self) -> str: return ( f"is_monte_carlo={self.is_monte_carlo}, temperature={self.temperature}, " f"scale_factor={self.scale_factor}" ) ================================================ FILE: nobrainer/layers/gaussian_dropout.py ================================================ """Gaussian dropout layer for PyTorch.""" import math import torch import torch.nn as nn class GaussianDropout(nn.Module): """Gaussian (multiplicative) dropout layer. Multiplies the input by noise sampled from ``Normal(1, σ²)`` where σ is derived from ``rate``. When ``scale_during_training`` is ``True``, σ = sqrt(rate / (1 - rate)) (variance-preserving during training); otherwise σ = sqrt(rate * (1 - rate)). Parameters ---------- rate : float Drop probability (0 ≤ rate < 1). is_monte_carlo : bool When ``True``, noise is applied regardless of ``training`` mode. scale_during_training : bool Selects which σ formula is used (see above). seed : int or None Optional RNG seed. References ---------- Dropout: A Simple Way to Prevent Neural Networks from Overfitting. N. Srivastava et al., JMLR 2014. 
""" def __init__( self, rate: float, is_monte_carlo: bool, scale_during_training: bool = True, seed: int | None = None, ) -> None: super().__init__() if not 0.0 <= rate < 1.0: raise ValueError(f"rate must be in [0, 1), got {rate}") self.rate = rate self.is_monte_carlo = is_monte_carlo self.scale_during_training = scale_during_training self._generator: torch.Generator | None = None if seed is not None: self._generator = torch.Generator() self._generator.manual_seed(seed) if scale_during_training: self._stddev = math.sqrt(rate / (1.0 - rate)) else: self._stddev = math.sqrt(rate * (1.0 - rate)) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.is_monte_carlo or self.training: noise = torch.randn_like(x, generator=self._generator) * self._stddev + 1.0 return x * noise return x def extra_repr(self) -> str: return ( f"rate={self.rate}, is_monte_carlo={self.is_monte_carlo}, " f"scale_during_training={self.scale_during_training}" ) ================================================ FILE: nobrainer/layers/maxpool4d.py ================================================ """MaxPool4D layer for PyTorch. Implements 4-D max-pooling (N, C, V, D, H, W) by treating the volume dimension V as a batch dimension and applying ``nn.MaxPool3d`` over (D, H, W). This avoids the need for a custom CUDA kernel. """ import torch import torch.nn as nn class MaxPool4D(nn.Module): """Max-pooling over 4 spatial dimensions. Expects input of shape ``(N, C, V, D, H, W)`` and applies ``kernel_size`` / ``stride`` / ``padding`` along the last 3 dimensions (D, H, W). The volume dimension V is reduced with ``pool_v`` if ``> 1``. Parameters ---------- kernel_size : int or tuple Kernel size for the (D, H, W) axes. stride : int or tuple or None Stride; defaults to ``kernel_size``. padding : int or tuple Zero-padding added to all spatial sides. pool_v : int Max-pool kernel size along the volume (V) axis. ``1`` leaves V unchanged. 
""" def __init__( self, kernel_size: int | tuple[int, int, int], stride: int | tuple[int, int, int] | None = None, padding: int | tuple[int, int, int] = 0, pool_v: int = 1, ) -> None: super().__init__() self.pool3d = nn.MaxPool3d(kernel_size, stride=stride, padding=padding) self.pool_v = pool_v def forward(self, x: torch.Tensor) -> torch.Tensor: if x.dim() != 6: raise ValueError( f"MaxPool4D expects 6-D input (N, C, V, D, H, W), got {x.dim()}-D" ) N, C, V, D, H, W = x.shape # Merge batch and volume dims so MaxPool3d sees (N*V, C, D, H, W) out = self.pool3d(x.view(N * V, C, D, H, W)) _, _, D2, H2, W2 = out.shape out = out.view(N, C, V, D2, H2, W2) if self.pool_v > 1: # Max over V with stride pool_v (using unfold for non-overlapping) out = out.unfold(2, self.pool_v, self.pool_v).amax(dim=-1) return out def extra_repr(self) -> str: return f"pool3d={self.pool3d}, pool_v={self.pool_v}" ================================================ FILE: nobrainer/layers/padding.py ================================================ """Custom padding layers for nobrainer (PyTorch).""" import torch import torch.nn as nn import torch.nn.functional as F class ZeroPadding3DChannels(nn.Module): """Pad the channel dimension of a 5-D tensor symmetrically with zeros. Expects input of shape ``(N, C, D, H, W)`` and pads ``C`` by ``padding`` on each side, yielding ``(N, C + 2*padding, D, H, W)``. Parameters ---------- padding : int Number of zero channels to prepend and append. 
""" def __init__(self, padding: int) -> None: super().__init__() self.padding = padding def forward(self, x: torch.Tensor) -> torch.Tensor: # F.pad pads in reverse dim order; last two entries pad dim 1 (C) return F.pad(x, (0, 0, 0, 0, 0, 0, self.padding, self.padding)) def extra_repr(self) -> str: return f"padding={self.padding}" ================================================ FILE: nobrainer/layers/tests/__init__.py ================================================ ================================================ FILE: nobrainer/losses.py ================================================ """Loss functions for 3-D semantic segmentation (PyTorch / MONAI).""" from __future__ import annotations from monai.losses import DiceLoss, GeneralizedDiceLoss, TverskyLoss import torch # --------------------------------------------------------------------------- # Convenience factory functions # --------------------------------------------------------------------------- def dice( sigmoid: bool = False, softmax: bool = False, squared_pred: bool = False, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, **kwargs, ) -> DiceLoss: """Return a MONAI ``DiceLoss`` instance. Parameters ---------- sigmoid : bool Apply sigmoid to predictions before computing Dice. softmax : bool Apply softmax to predictions before computing Dice. squared_pred : bool Use squared predictions in the denominator. smooth_nr, smooth_dr : float Numerator/denominator smoothing to avoid division by zero. **kwargs Extra keyword arguments forwarded to ``monai.losses.DiceLoss``. 
""" return DiceLoss( sigmoid=sigmoid, softmax=softmax, squared_pred=squared_pred, smooth_nr=smooth_nr, smooth_dr=smooth_dr, **kwargs, ) def generalized_dice( sigmoid: bool = False, softmax: bool = False, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, **kwargs, ) -> GeneralizedDiceLoss: """Return a MONAI ``GeneralizedDiceLoss`` instance.""" return GeneralizedDiceLoss( sigmoid=sigmoid, softmax=softmax, smooth_nr=smooth_nr, smooth_dr=smooth_dr, **kwargs, ) def jaccard( sigmoid: bool = False, softmax: bool = False, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, **kwargs, ) -> DiceLoss: """Return a Dice loss configured for Jaccard (IoU) computation. The Jaccard index equals ``intersection / union``; setting ``jaccard=True`` in MONAI's ``DiceLoss`` switches the denominator accordingly. """ return DiceLoss( sigmoid=sigmoid, softmax=softmax, jaccard=True, smooth_nr=smooth_nr, smooth_dr=smooth_dr, **kwargs, ) def tversky( alpha: float = 0.3, beta: float = 0.7, sigmoid: bool = False, softmax: bool = False, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, **kwargs, ) -> TverskyLoss: """Return a MONAI ``TverskyLoss`` instance. Parameters ---------- alpha : float Weight of false positives. beta : float Weight of false negatives. """ return TverskyLoss( alpha=alpha, beta=beta, sigmoid=sigmoid, softmax=softmax, smooth_nr=smooth_nr, smooth_dr=smooth_dr, **kwargs, ) # --------------------------------------------------------------------------- # Stubs — implemented in US2 (elbo) and US3 (wasserstein) # --------------------------------------------------------------------------- def elbo( model: torch.nn.Module, kl_weight: float, reconstruction_loss: torch.Tensor, ) -> torch.Tensor: """Compute ELBO = reconstruction_loss + kl_weight * KL. The KL term is accumulated by Pyro sampling during the forward pass of Bayesian modules (:class:`~nobrainer.models.bayesian.layers.BayesianConv3d` and :class:`~nobrainer.models.bayesian.layers.BayesianLinear`). 
Parameters ---------- model : nn.Module A model with one or more Bayesian layers whose ``.kl`` attributes have been populated by a recent forward pass. kl_weight : float Scalar multiplier for the KL divergence term (often ``1 / N_data`` or ``1 / N_batches``). reconstruction_loss : torch.Tensor Scalar reconstruction loss (e.g., Dice or cross-entropy) already computed for the current batch. Returns ------- torch.Tensor Scalar ELBO = reconstruction_loss + kl_weight * KL. """ from .models.bayesian.utils import accumulate_kl kl = accumulate_kl(model) return reconstruction_loss + kl_weight * kl def wasserstein(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: """Wasserstein critic loss: ``E[D(fake)] - E[D(real)]``. Parameters ---------- y_true : torch.Tensor Critic scores for real samples, shape ``(N,)`` or ``(N, 1)``. y_pred : torch.Tensor Critic scores for fake samples, shape ``(N,)`` or ``(N, 1)``. Returns ------- torch.Tensor Scalar Wasserstein critic loss (minimised by the discriminator). """ return y_pred.mean() - y_true.mean() def gradient_penalty( discriminator: torch.nn.Module, real: torch.Tensor, fake: torch.Tensor, lambda_gp: float = 10.0, ) -> torch.Tensor: """WGAN-GP gradient penalty. Interpolates between ``real`` and ``fake`` samples and penalises the discriminator gradient norm for deviating from 1. Parameters ---------- discriminator : nn.Module The discriminator / critic network. real : torch.Tensor Real samples, shape ``(N, C, D, H, W)``. fake : torch.Tensor Generated samples, same shape as ``real``. lambda_gp : float Penalty weight (default 10, standard WGAN-GP value). Returns ------- torch.Tensor Scalar gradient penalty term. 
""" b = real.size(0) eps = torch.rand(b, *([1] * (real.dim() - 1)), device=real.device) interp = (eps * real + (1.0 - eps) * fake.detach()).requires_grad_(True) d_interp = discriminator(interp) grads = torch.autograd.grad( outputs=d_interp, inputs=interp, grad_outputs=torch.ones_like(d_interp), create_graph=True, retain_graph=True, )[0] gp = ((grads.norm(2, dim=list(range(1, real.dim()))) - 1) ** 2).mean() return lambda_gp * gp # --------------------------------------------------------------------------- # Class weights and weighted losses # --------------------------------------------------------------------------- def compute_class_weights( label_paths: list[str], n_classes: int, label_mapping: str | None = None, method: str = "inverse_frequency", max_samples: int | None = None, ) -> torch.Tensor: """Compute per-class weights from label volumes. Scans label files to count voxel frequencies per class, then converts to weights. Useful for imbalanced segmentation (e.g., 50-class brain parcellation where small structures are underrepresented). Parameters ---------- label_paths : list of str Paths to label NIfTI/MGZ files. n_classes : int Number of target classes. label_mapping : str or None Label mapping name (e.g., ``"50-class"``) or CSV path. If None, labels are used as-is. method : str ``"inverse_frequency"`` (1/freq, normalized) or ``"median_frequency"`` (median_freq/freq, as in SegNet). max_samples : int or None Limit scanning to this many files (for speed). Returns ------- torch.Tensor Shape ``(n_classes,)`` float tensor of weights. 
""" import nibabel as nib import numpy as np counts = np.zeros(n_classes, dtype=np.float64) paths = label_paths[:max_samples] if max_samples else label_paths remap_fn = None if label_mapping is not None: from nobrainer.processing.dataset import _load_label_mapping remap_fn = _load_label_mapping(label_mapping) for path in paths: arr = np.asarray(nib.load(path).dataobj, dtype=np.int32) if remap_fn is not None: arr = remap_fn(arr) for c in range(n_classes): counts[c] += (arr == c).sum() # Avoid division by zero counts = np.maximum(counts, 1.0) total = counts.sum() if method == "median_frequency": freqs = counts / total median_freq = np.median(freqs[freqs > 0]) weights = median_freq / freqs else: # inverse_frequency: weight = total / (n_classes * count) weights = total / (n_classes * counts) # Normalize so mean weight = 1 weights = weights / weights.mean() return torch.tensor(weights, dtype=torch.float32) def weighted_cross_entropy( weight: torch.Tensor | None = None, label_smoothing: float = 0.0, ) -> torch.nn.CrossEntropyLoss: """Return a ``CrossEntropyLoss`` with optional per-class weights. Parameters ---------- weight : torch.Tensor or None Per-class weights, shape ``(n_classes,)``. label_smoothing : float Label smoothing factor (default 0). """ return torch.nn.CrossEntropyLoss( weight=weight, label_smoothing=label_smoothing, ) class HammingLoss(torch.nn.Module): """Hamming loss: fraction of misclassified voxels. A differentiable approximation of Hamming distance using soft predictions: ``g·(1-p) + (1-g)·p`` averaged over spatial dims. For use as a loss function with logits, set ``from_logits=True`` to apply softmax first. Parameters ---------- from_logits : bool Apply softmax to predictions (default True). 
""" def __init__(self, from_logits: bool = True) -> None: super().__init__() self.from_logits = from_logits def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.from_logits: pred = torch.softmax(pred, dim=1) # One-hot encode target if needed if target.ndim == pred.ndim - 1: n_classes = pred.shape[1] target_oh = ( torch.nn.functional.one_hot(target.long(), n_classes) .permute(0, 4, 1, 2, 3) .float() ) else: target_oh = target.float() # Hamming: g*(1-p) + (1-g)*p = fraction of disagreement loss = target_oh * (1 - pred) + (1 - target_oh) * pred return loss.mean() def hamming(from_logits: bool = True) -> HammingLoss: """Return a :class:`HammingLoss` instance.""" return HammingLoss(from_logits=from_logits) class DiceCELoss(torch.nn.Module): """Combined Dice + weighted CrossEntropy loss. Commonly used for imbalanced segmentation tasks. The Dice component is inherently class-balanced; the CE component can use per-class weights. Parameters ---------- weight : torch.Tensor or None Per-class weights for the CE term. dice_weight : float Relative weight of the Dice term (default 1.0). ce_weight : float Relative weight of the CE term (default 1.0). softmax : bool Apply softmax to predictions for the Dice term. label_smoothing : float Label smoothing for the CE term. 
""" def __init__( self, weight: torch.Tensor | None = None, dice_weight: float = 1.0, ce_weight: float = 1.0, softmax: bool = True, label_smoothing: float = 0.0, ) -> None: super().__init__() self.dice_loss = DiceLoss(softmax=softmax, to_onehot_y=True) self.ce_loss = torch.nn.CrossEntropyLoss( weight=weight, label_smoothing=label_smoothing ) self.dice_weight = dice_weight self.ce_weight = ce_weight def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # Dice expects target with channel dim if target.ndim == pred.ndim - 1: target_dice = target.unsqueeze(1) else: target_dice = target d = self.dice_loss(pred, target_dice) # CE expects target without channel dim, as long if target.ndim == pred.ndim: target_ce = target.squeeze(1) else: target_ce = target if target_ce.dtype != torch.long: target_ce = target_ce.long() ce = self.ce_loss(pred, target_ce) return self.dice_weight * d + self.ce_weight * ce class FocalLoss(torch.nn.Module): """Focal Loss for imbalanced multi-class segmentation. Down-weights well-classified examples and focuses on hard ones. ``FL(p) = -α · (1 - p)^γ · log(p)`` Parameters ---------- gamma : float Focusing parameter (default 2.0). Higher = more focus on hard examples. alpha : torch.Tensor or None Per-class weights. None = uniform. 
""" def __init__( self, gamma: float = 2.0, alpha: torch.Tensor | None = None, ) -> None: super().__init__() self.gamma = gamma if alpha is not None: self.register_buffer("alpha", alpha) else: self.alpha = None def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.ndim == pred.ndim - 1: pass # expected: target is (B, D, H, W), pred is (B, C, D, H, W) elif target.ndim == pred.ndim and target.shape[1] == 1: target = target.squeeze(1) target = target.long() ce = torch.nn.functional.cross_entropy(pred, target, reduction="none") p = torch.exp(-ce) # probability of correct class focal_weight = (1 - p) ** self.gamma if self.alpha is not None: alpha_t = self.alpha[target] focal_weight = focal_weight * alpha_t return (focal_weight * ce).mean() def focal(gamma: float = 2.0, alpha: torch.Tensor | None = None) -> FocalLoss: """Return a :class:`FocalLoss` instance.""" return FocalLoss(gamma=gamma, alpha=alpha) # --------------------------------------------------------------------------- # Registry # --------------------------------------------------------------------------- _losses = { "dice": dice, "generalized_dice": generalized_dice, "jaccard": jaccard, "tversky": tversky, "elbo": elbo, "wasserstein": wasserstein, "gradient_penalty": gradient_penalty, "hamming": hamming, "focal": focal, "weighted_cross_entropy": weighted_cross_entropy, "dice_ce": DiceCELoss, } def get(name: str): """Return loss factory by name (case-insensitive).""" try: return _losses[name.lower()] except KeyError: avail = ", ".join(_losses) raise ValueError(f"Unknown loss '{name}'. 
Available: {avail}") from None ================================================ FILE: nobrainer/metrics.py ================================================ """Evaluation metrics for 3-D semantic segmentation (PyTorch / MONAI).""" from __future__ import annotations from monai.metrics import DiceMetric, HausdorffDistanceMetric, MeanIoU import torch # --------------------------------------------------------------------------- # Factory functions returning configured MONAI metric objects # --------------------------------------------------------------------------- def dice_metric( include_background: bool = True, reduction: str = "mean", **kwargs, ) -> DiceMetric: """Return a MONAI ``DiceMetric`` instance. Parameters ---------- include_background : bool Include the background class in the Dice computation. reduction : str Reduction applied over the batch (``"mean"``, ``"sum"``, ``"none"``). """ return DiceMetric( include_background=include_background, reduction=reduction, **kwargs, ) def generalized_dice_metric( include_background: bool = True, reduction: str = "mean", **kwargs, ) -> DiceMetric: """Return a ``DiceMetric`` configured for multi-class (generalised) Dice. MONAI's ``DiceMetric`` computes per-class Dice and averages over classes, which is equivalent to Generalized Dice when class weights are uniform. """ return DiceMetric( include_background=include_background, reduction=reduction, **kwargs, ) def jaccard_metric( include_background: bool = True, reduction: str = "mean", **kwargs, ) -> MeanIoU: """Return a MONAI ``MeanIoU`` (Jaccard) metric instance.""" return MeanIoU( include_background=include_background, reduction=reduction, **kwargs, ) def tversky_metric( include_background: bool = True, reduction: str = "mean", **kwargs, ) -> DiceMetric: """Return a ``DiceMetric`` used as a Tversky metric proxy. Tversky with alpha=beta=0.5 equals Dice. For asymmetric Tversky, compute the Tversky index manually and wrap it in a custom metric. 
""" return DiceMetric( include_background=include_background, reduction=reduction, **kwargs, ) def hausdorff_metric( include_background: bool = False, distance_metric: str = "euclidean", percentile: float | None = 95.0, directed: bool = False, **kwargs, ) -> HausdorffDistanceMetric: """Return a MONAI ``HausdorffDistanceMetric`` instance. Parameters ---------- include_background : bool Include background class in distance computation. distance_metric : str ``"euclidean"``, ``"chessboard"``, or ``"taxicab"``. percentile : float or None If set, computes the *n*-th percentile Hausdorff distance (e.g. 95 for HD95). ``None`` returns the maximum (HD100). directed : bool Compute directed (asymmetric) Hausdorff distance. """ return HausdorffDistanceMetric( include_background=include_background, distance_metric=distance_metric, percentile=percentile, directed=directed, **kwargs, ) def hamming_metric(reduction: str = "mean") -> "HammingMetric": """Return a Hamming distance metric (fraction of misclassified voxels). Unlike MONAI metrics, this is a simple callable that takes ``(y_pred, y_true)`` integer label tensors and returns the mean fraction of disagreeing voxels. 
""" return HammingMetric(reduction=reduction) class HammingMetric: """Hamming distance metric: fraction of voxels where prediction != label.""" def __init__(self, reduction: str = "mean") -> None: self.reduction = reduction def __call__( self, y_pred: torch.Tensor, y_true: torch.Tensor, ) -> torch.Tensor: ne = (y_pred != y_true).float() # Average over spatial dims per sample spatial = list(range(1, ne.ndim)) per_sample = ne.mean(dim=spatial) if self.reduction == "mean": return per_sample.mean() if self.reduction == "sum": return per_sample.sum() return per_sample # "none" # --------------------------------------------------------------------------- # Registry # --------------------------------------------------------------------------- _metrics = { "dice": dice_metric, "generalized_dice": generalized_dice_metric, "jaccard": jaccard_metric, "tversky": tversky_metric, "hausdorff": hausdorff_metric, "hamming": hamming_metric, } def get(name: str): """Return metric factory by name (case-insensitive).""" try: return _metrics[name.lower()] except KeyError: avail = ", ".join(_metrics) raise ValueError(f"Unknown metric '{name}'. 
Available: {avail}") from None ================================================ FILE: nobrainer/models/__init__.py ================================================ """Nobrainer model registry (PyTorch).""" from pprint import pprint from .autoencoder import autoencoder from .highresnet import highresnet from .meshnet import meshnet from .segformer3d import segformer3d from .segmentation import attention_unet, segresnet, swin_unetr, unet, unetr, vnet from .simsiam import simsiam __all__ = ["get", "list_available_models"] # Core models (always available) _models = { "unet": unet, "vnet": vnet, "attention_unet": attention_unet, "unetr": unetr, "meshnet": meshnet, "highresnet": highresnet, "autoencoder": autoencoder, "simsiam": simsiam, "swin_unetr": swin_unetr, "segresnet": segresnet, "segformer3d": segformer3d, } # Optional: Bayesian models (require pyro-ppl) try: from .bayesian import bayesian_meshnet, bayesian_vnet _models["bayesian_vnet"] = bayesian_vnet _models["bayesian_meshnet"] = bayesian_meshnet except ImportError: pass # KWYK MeshNet (VWN-based, no Pyro dependency) from .bayesian.kwyk_meshnet import kwyk_meshnet # noqa: E402 _models["kwyk_meshnet"] = kwyk_meshnet # Optional: Generative models (require pytorch-lightning) try: from .generative import dcgan, progressivegan _models["progressivegan"] = progressivegan _models["dcgan"] = dcgan except ImportError: pass def get(name: str): """Return factory callable for a model by name (case-insensitive). Parameters ---------- name : str Model name. Returns ------- Callable that constructs a ``torch.nn.Module``. 
""" if not isinstance(name, str): raise ValueError("Model name must be a string.") key = name.lower() if key in _models: return _models[key] # Check if it's an optional model that wasn't loaded optional = { "bayesian_vnet": "pyro-ppl", "bayesian_meshnet": "pyro-ppl", "progressivegan": "pytorch-lightning", "dcgan": "pytorch-lightning", } if key in optional: raise ImportError( f"Model '{name}' requires '{optional[key]}'. " f"Install with: uv pip install {optional[key]}" ) avail = ", ".join(_models) raise ValueError(f"Unknown model '{name}'. Available: {avail}.") def available_models() -> list[str]: return list(_models) def list_available_models() -> None: pprint(available_models()) ================================================ FILE: nobrainer/models/_constants.py ================================================ """Shared constants for nobrainer models.""" from __future__ import annotations # Dilation schedules indexed by receptive field size. # Used by MeshNet, BayesianMeshNet, and KWYKMeshNet. DILATION_SCHEDULES: dict[int, list[int]] = { 37: [1, 1, 1, 2, 4, 8, 1], 67: [1, 1, 2, 4, 8, 16, 1], 129: [1, 2, 4, 8, 16, 32, 1], } ================================================ FILE: nobrainer/models/_utils.py ================================================ """Shared utilities for nobrainer models and training.""" from __future__ import annotations from pathlib import Path import nibabel as nib import numpy as np import torch def unpack_batch( batch: dict | list | tuple, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: """Extract image and label tensors from a batch, move to device. Handles both dict-style (MONAI) and tuple-style (TensorDataset) batches. Squeezes label channel dim and casts to long for CrossEntropyLoss. Parameters ---------- batch : dict, list, or tuple A batch from a DataLoader. device : torch.device Target device. Returns ------- images : torch.Tensor Shape ``(B, C, D, H, W)`` on *device*. 
labels : torch.Tensor Shape ``(B, D, H, W)`` long dtype on *device*. """ if isinstance(batch, dict): images = batch["image"].to(device) labels = batch["label"].to(device) elif isinstance(batch, (list, tuple)): images = batch[0].to(device) labels = batch[1].to(device) else: raise TypeError(f"Unsupported batch type: {type(batch)}") # Squeeze channel dim from labels if present if labels.ndim == images.ndim and labels.shape[1] == 1: labels = labels.squeeze(1) # Cast float labels to long for CrossEntropyLoss if labels.dtype in (torch.float32, torch.float64): labels = labels.long() return images, labels def load_input( inputs: str | Path | np.ndarray | nib.Nifti1Image, ) -> tuple[np.ndarray, np.ndarray | None]: """Load a 3D volume from various input types. Parameters ---------- inputs : str, Path, ndarray, or Nifti1Image Input volume. Returns ------- arr : np.ndarray 3D array, shape ``(D, H, W)``. affine : np.ndarray or None 4x4 affine matrix (None if input is raw array). """ if isinstance(inputs, (str, Path)): img = nib.load(inputs) return np.asarray(img.dataobj, dtype=np.float32), img.affine elif isinstance(inputs, nib.Nifti1Image): return np.asarray(inputs.dataobj, dtype=np.float32), inputs.affine elif isinstance(inputs, np.ndarray): return inputs.astype(np.float32), None else: raise TypeError(f"Unsupported input type: {type(inputs)}") def model_supports_mc(model: torch.nn.Module) -> bool: """Check if a model supports the ``mc`` keyword argument in forward(). Returns True if the model has a ``supports_mc`` class attribute set to True, or if its forward method accepts an ``mc`` parameter. """ if getattr(model, "supports_mc", False): return True # Check the forward signature import inspect sig = inspect.signature(model.forward) return "mc" in sig.parameters ================================================ FILE: nobrainer/models/autoencoder.py ================================================ """Symmetric 3-D autoencoder (PyTorch). 
Encodes a 3-D volume into a flat latent vector and reconstructs it via transposed convolutions. """ from __future__ import annotations import math import torch import torch.nn as nn class Autoencoder(nn.Module): """Symmetric 3-D convolutional autoencoder. Dynamically builds encoder depth from the spatial size of the input. Parameters ---------- input_shape : tuple of int Volume shape ``(D, H, W)`` (spatial dims only). in_channels : int Number of input channels (1 for single-modality MRI). encoding_dim : int Size of the flat latent code. n_base_filters : int Base filter count; doubled each encoder level. batchnorm : bool Whether to apply Batch Normalisation in conv blocks. """ def __init__( self, input_shape: tuple[int, int, int] = (64, 64, 64), in_channels: int = 1, encoding_dim: int = 512, n_base_filters: int = 16, batchnorm: bool = True, ) -> None: super().__init__() D = input_shape[0] n_levels = int(math.log2(D)) # Build encoder enc_layers: list[nn.Module] = [] ch_in = in_channels self._enc_channels: list[int] = [] for i in range(n_levels): ch_out = min(n_base_filters * (2**i), encoding_dim) self._enc_channels.append(ch_out) block: list[nn.Module] = [ nn.Conv3d( ch_in, ch_out, kernel_size=4, stride=2, padding=1, bias=not batchnorm, ), ] if batchnorm: block.append(nn.BatchNorm3d(ch_out)) block.append(nn.ReLU(inplace=True)) enc_layers.extend(block) ch_in = ch_out self.encoder_conv = nn.Sequential(*enc_layers) self.encoder_fc = nn.Linear(ch_in, encoding_dim) # Build decoder (mirror of encoder) dec_ch = list(reversed(self._enc_channels)) self.decoder_fc = nn.Linear(encoding_dim, dec_ch[0]) dec_layers: list[nn.Module] = [] all_out = dec_ch[1:] + [in_channels] for i, ch_out in enumerate(all_out): ch_in_d = dec_ch[i] is_last = i == len(all_out) - 1 act_d: nn.Module = ( nn.Sigmoid() if is_last else nn.LeakyReLU(0.2, inplace=True) ) use_bn = batchnorm and not is_last block_d: list[nn.Module] = [ nn.ConvTranspose3d( ch_in_d, ch_out, kernel_size=4, stride=2, padding=1, 
bias=not use_bn ), ] if use_bn: block_d.append(nn.BatchNorm3d(ch_out)) block_d.append(act_d) dec_layers.extend(block_d) self.decoder_conv = nn.Sequential(*dec_layers) def encode(self, x: torch.Tensor) -> torch.Tensor: h = self.encoder_conv(x) # (N, C, 1, 1, 1) return self.encoder_fc(h.flatten(1)) def decode(self, z: torch.Tensor) -> torch.Tensor: h = self.decoder_fc(z).view(z.size(0), -1, 1, 1, 1) return self.decoder_conv(h) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.decode(self.encode(x)) def autoencoder( input_shape: tuple[int, int, int] = (64, 64, 64), in_channels: int = 1, encoding_dim: int = 512, n_base_filters: int = 16, batchnorm: bool = True, **kwargs, ) -> Autoencoder: """Factory function for :class:`Autoencoder`.""" return Autoencoder( input_shape=input_shape, in_channels=in_channels, encoding_dim=encoding_dim, n_base_filters=n_base_filters, batchnorm=batchnorm, ) __all__ = ["Autoencoder", "autoencoder"] ================================================ FILE: nobrainer/models/bayesian/__init__.py ================================================ """Bayesian model sub-package. Two flavours of Bayesian convolution are provided: * **Bayes-by-backprop** (``BayesianConv3d``, ``BayesianMeshNet``) — Pyro-based, weight uncertainty via learned mu/sigma, supports standard_normal/laplace priors. * **Variational Weight Normalization** (``VWNConv3d``, ``KWYKMeshNet``) — matches the original kwyk architecture (McClure et al. 2019) with weight normalization, local reparameterization, and Bernoulli or Concrete dropout. 
"""

# Public API of the Bayesian sub-package: Pyro-based layers/models, the
# VWN/FFG layers and KWYK MeshNet, their factory functions, and the
# ``accumulate_kl`` helper.
from .bayesian_meshnet import BayesianMeshNet, bayesian_meshnet
from .bayesian_vnet import BayesianVNet, bayesian_vnet
from .kwyk_meshnet import KWYKMeshNet, kwyk_meshnet
from .layers import BayesianConv3d, BayesianLinear
from .utils import accumulate_kl
from .vwn_layers import ConcreteDropout3d, FFGConv3d, VWNConv3d

__all__ = [
    "BayesianConv3d",
    "BayesianLinear",
    "BayesianMeshNet",
    "BayesianVNet",
    "ConcreteDropout3d",
    "FFGConv3d",
    "KWYKMeshNet",
    "VWNConv3d",
    "accumulate_kl",
    "bayesian_meshnet",
    "bayesian_vnet",
    "kwyk_meshnet",
]

================================================
FILE: nobrainer/models/bayesian/bayesian_meshnet.py
================================================
"""Bayesian MeshNet: dilated-convolution segmentation with weight uncertainty.

Replaces every ``nn.Conv3d`` in the 7-layer dilated architecture with
:class:`~nobrainer.models.bayesian.layers.BayesianConv3d`.

Reference
---------
Fedorov A. et al., "End-to-end learning of brain tissue segmentation from
imperfect labeling", IJCNN 2017. arXiv:1612.00940.
"""

from __future__ import annotations

from pyro.nn import PyroModule
import torch
import torch.nn as nn
import torch.nn.functional as F

from nobrainer.models._constants import (  # noqa: E501
    DILATION_SCHEDULES as _DILATION_SCHEDULES,
)

from .layers import BayesianConv3d


class _BayesConvBNActDrop(PyroModule):
    """Single dilated Bayesian conv layer with BN + ELU/ReLU + spatial dropout.

    Computes ``dropout(act(bn(conv(x))))`` where ``conv`` is a 3x3x3,
    bias-free :class:`BayesianConv3d`. ``activation`` must be ``"relu"`` or
    ``"elu"`` (case-insensitive); any other value raises ``KeyError``.
    Extra ``**sas_kwargs`` are forwarded to :class:`BayesianConv3d`.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        dilation: int,
        activation: str,
        dropout_rate: float,
        prior_type: str,
        **sas_kwargs,
    ) -> None:
        super().__init__()
        padding = dilation  # same-size output for 3×3×3 kernel
        self.conv = BayesianConv3d(
            in_ch,
            out_ch,
            kernel_size=3,
            padding=padding,
            dilation=dilation,
            bias=False,
            prior_type=prior_type,
            **sas_kwargs,
        )
        self.bn = nn.BatchNorm3d(out_ch)
        # Dict lookup doubles as validation of ``activation``.
        self.act_fn = {"relu": F.relu, "elu": F.elu}[activation.lower()]
        self.dropout = nn.Dropout3d(p=dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.act_fn(self.bn(self.conv(x))))


class BayesianMeshNet(PyroModule):
    """3-D MeshNet with Bayesian convolutional layers.

    Identical dilated-convolution schedule as
    :class:`~nobrainer.models.meshnet.MeshNet` but all ``nn.Conv3d`` layers
    are replaced with :class:`BayesianConv3d`.

    Parameters
    ----------
    n_classes : int
        Number of output segmentation classes.
    in_channels : int
        Number of input image channels.
    filters : int
        Feature-map count in all hidden layers.
    receptive_field : int
        One of ``37``, ``67``, ``129`` — selects the dilation schedule.
    activation : str
        ``"relu"`` or ``"elu"``.
    dropout_rate : float
        Spatial dropout probability (0 = disabled).
    prior_type : str
        ``"standard_normal"``, ``"laplace"``, or ``"spike_and_slab"``.
    kl_weight : float
        Scalar applied to the summed KL when computing the ELBO.
        Stored as an attribute; not used internally during forward.
    spike_sigma : float
        Spike component σ for spike-and-slab prior (default 0.001).
        Only forwarded to the layers when ``prior_type="spike_and_slab"``.
    slab_sigma : float
        Slab component σ for spike-and-slab prior (default 1.0).
        Only forwarded to the layers when ``prior_type="spike_and_slab"``.
    prior_pi : float
        Prior probability of the spike component (default 0.5).
        Only forwarded to the layers when ``prior_type="spike_and_slab"``.

    Raises
    ------
    ValueError
        If ``receptive_field`` is not a key of the dilation-schedule table.
    """

    def __init__(
        self,
        n_classes: int = 1,
        in_channels: int = 1,
        filters: int = 71,
        receptive_field: int = 67,
        activation: str = "relu",
        dropout_rate: float = 0.25,
        prior_type: str = "standard_normal",
        kl_weight: float = 1.0,
        spike_sigma: float = 0.001,
        slab_sigma: float = 1.0,
        prior_pi: float = 0.5,
    ) -> None:
        super().__init__()
        if receptive_field not in _DILATION_SCHEDULES:
            raise ValueError(
                f"receptive_field must be one of {list(_DILATION_SCHEDULES)}, "
                f"got {receptive_field}"
            )
        self.kl_weight = kl_weight
        self.prior_type = prior_type
        dilations = _DILATION_SCHEDULES[receptive_field]
        self._n_layers = len(dilations)

        # Extra kwargs for spike-and-slab layers
        sas_kwargs = {}
        if prior_type == "spike_and_slab":
            sas_kwargs = {
                "spike_sigma": spike_sigma,
                "slab_sigma": slab_sigma,
                "prior_pi": prior_pi,
            }

        # Register each Bayesian layer as a named attribute so Pyro assigns
        # unique sample site names (nn.ModuleList does not propagate names).
        for i, dil in enumerate(dilations):
            in_ch = in_channels if i == 0 else filters
            layer = _BayesConvBNActDrop(
                in_ch,
                filters,
                dil,
                activation,
                dropout_rate,
                prior_type,
                **sas_kwargs,
            )
            setattr(self, f"layer_{i}", layer)

        # Final 1×1×1 classifier — deterministic
        self.classifier = nn.Conv3d(filters, n_classes, kernel_size=1)

    # Advertises MC support; ``forward`` accepts and ignores extra keywords
    # such as ``mc`` because every Bayesian conv samples its weights anyway.
    supports_mc = True

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        h = x
        for i in range(self._n_layers):
            h = getattr(self, f"layer_{i}")(h)
        return self.classifier(h)


def bayesian_meshnet(
    n_classes: int = 1,
    in_channels: int = 1,
    filters: int = 71,
    receptive_field: int = 67,
    activation: str = "relu",
    dropout_rate: float = 0.25,
    prior_type: str = "standard_normal",
    kl_weight: float = 1.0,
    spike_sigma: float = 0.001,
    slab_sigma: float = 1.0,
    prior_pi: float = 0.5,
    **kwargs,
) -> BayesianMeshNet:
    """Factory function for :class:`BayesianMeshNet`.

    Extra keyword arguments are accepted and ignored.
    """
    return BayesianMeshNet(
        n_classes=n_classes,
        in_channels=in_channels,
        filters=filters,
        receptive_field=receptive_field,
        activation=activation,
        dropout_rate=dropout_rate,
        prior_type=prior_type,
        kl_weight=kl_weight,
        spike_sigma=spike_sigma,
        slab_sigma=slab_sigma,
        prior_pi=prior_pi,
    )


__all__ = ["BayesianMeshNet", "bayesian_meshnet"]

================================================
FILE: nobrainer/models/bayesian/bayesian_vnet.py
================================================
"""Bayesian V-Net: encoder-decoder segmentation with weight uncertainty.

Replaces the standard ``nn.Conv3d`` convolutions with
:class:`~nobrainer.models.bayesian.layers.BayesianConv3d` (mean-field
variational inference via Pyro), preserving the residual encoder-decoder
architecture of V-Net.

Reference
---------
Milletari F. et al., "V-Net: Fully Convolutional Neural Networks for
Volumetric Medical Image Segmentation", 3DV 2016. arXiv:1606.04797.
""" from __future__ import annotations from pyro.nn import PyroModule import torch import torch.nn as nn import torch.nn.functional as F from .layers import BayesianConv3d class _BayesResBlock(PyroModule): """Two-layer residual block with BayesianConv3d and a skip connection.""" def __init__(self, channels: int, prior_type: str = "standard_normal") -> None: super().__init__() self.conv1 = BayesianConv3d( channels, channels, kernel_size=3, padding=1, prior_type=prior_type ) self.conv2 = BayesianConv3d( channels, channels, kernel_size=3, padding=1, prior_type=prior_type ) self.bn1 = nn.BatchNorm3d(channels) self.bn2 = nn.BatchNorm3d(channels) def forward(self, x: torch.Tensor) -> torch.Tensor: h = F.elu(self.bn1(self.conv1(x))) h = self.bn2(self.conv2(h)) return F.elu(h + x) class _EncoderBlock(PyroModule): """One encoder level: project channels → residual block → max-pool.""" def __init__( self, in_ch: int, out_ch: int, prior_type: str = "standard_normal", ) -> None: super().__init__() self.proj = BayesianConv3d(in_ch, out_ch, kernel_size=1, prior_type=prior_type) self.bn_proj = nn.BatchNorm3d(out_ch) self.res = _BayesResBlock(out_ch, prior_type=prior_type) self.pool = nn.MaxPool3d(kernel_size=2) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: h = F.elu(self.bn_proj(self.proj(x))) h = self.res(h) return self.pool(h), h # (down-sampled, skip) class _DecoderBlock(PyroModule): """One decoder level: up-sample → concat skip → project → residual block.""" def __init__( self, in_ch: int, skip_ch: int, out_ch: int, prior_type: str = "standard_normal", ) -> None: super().__init__() self.upsample = nn.ConvTranspose3d(in_ch, out_ch, kernel_size=2, stride=2) self.proj = BayesianConv3d( out_ch + skip_ch, out_ch, kernel_size=1, prior_type=prior_type ) self.bn_proj = nn.BatchNorm3d(out_ch) self.res = _BayesResBlock(out_ch, prior_type=prior_type) def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: h = self.upsample(x) if h.shape != 
skip.shape: h = F.interpolate( h, size=skip.shape[2:], mode="trilinear", align_corners=False ) h = torch.cat([h, skip], dim=1) h = F.elu(self.bn_proj(self.proj(h))) return self.res(h) class BayesianVNet(PyroModule): """3-D V-Net with Bayesian convolutional layers. All ``nn.Conv3d`` layers in the encoder and decoder are replaced with :class:`BayesianConv3d`. Upsampling transposed convolutions remain deterministic. Parameters ---------- n_classes : int Number of output segmentation classes. in_channels : int Number of input image channels. base_filters : int Feature-map count at the first encoder level (doubles each level). levels : int Number of encoder/decoder levels (default 4). prior_type : str ``"standard_normal"`` or ``"laplace"`` — forwarded to Bayesian layers. kl_weight : float Scalar applied to the summed KL divergence when computing the ELBO. Stored as an attribute; not used internally during forward. """ def __init__( self, n_classes: int = 1, in_channels: int = 1, base_filters: int = 16, levels: int = 4, prior_type: str = "standard_normal", kl_weight: float = 1.0, ) -> None: super().__init__() self.kl_weight = kl_weight self._levels = levels ch = [base_filters * (2**i) for i in range(levels)] # Input projection self.input_proj = BayesianConv3d( in_channels, ch[0], kernel_size=3, padding=1, prior_type=prior_type ) self.input_bn = nn.BatchNorm3d(ch[0]) # Encoder — registered as individually named attributes so Pyro can # assign unique site names (nn.ModuleList does not propagate names). # encoder_i: ch[i] → ch[i+1]; skip tensor has ch[i+1] channels. for i in range(levels - 1): enc = _EncoderBlock(ch[i], ch[i + 1], prior_type) setattr(self, f"encoder_{i}", enc) # Bottom residual block (no pooling) self.bottom_res = _BayesResBlock(ch[-1], prior_type=prior_type) # Decoder — decoder_i processes the stage closest to the bottom first. 
# decoder_i: in_ch = ch[L-1-i], skip_ch = ch[L-1-i], out_ch = ch[L-2-i] # (upsampled out_ch channels are cat'd with skip_ch to give # out_ch + skip_ch channels before the projection layer) L = levels for i in range(L - 1): in_ch = ch[L - 1 - i] skip_ch = ch[L - 1 - i] # skip from encoder_{L-2-i} has ch[L-1-i] chans out_ch = ch[L - 2 - i] dec = _DecoderBlock(in_ch, skip_ch, out_ch, prior_type) setattr(self, f"decoder_{i}", dec) # Final 1×1×1 classifier — deterministic self.classifier = nn.Conv3d(ch[0], n_classes, kernel_size=1) supports_mc = True def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: h = F.elu(self.input_bn(self.input_proj(x))) skips: list[torch.Tensor] = [] for i in range(self._levels - 1): enc = getattr(self, f"encoder_{i}") h, skip = enc(h) skips.append(skip) h = self.bottom_res(h) for i in range(self._levels - 1): dec = getattr(self, f"decoder_{i}") skip = skips[self._levels - 2 - i] h = dec(h, skip) return self.classifier(h) def bayesian_vnet( n_classes: int = 1, in_channels: int = 1, base_filters: int = 16, levels: int = 4, prior_type: str = "standard_normal", kl_weight: float = 1.0, **kwargs, ) -> BayesianVNet: """Factory function for :class:`BayesianVNet`.""" return BayesianVNet( n_classes=n_classes, in_channels=in_channels, base_filters=base_filters, levels=levels, prior_type=prior_type, kl_weight=kl_weight, ) __all__ = ["BayesianVNet", "bayesian_vnet"] ================================================ FILE: nobrainer/models/bayesian/kwyk_meshnet.py ================================================ """KWYK MeshNet variants — matching McClure et al. (2019) architecture. All three kwyk models use Fully Factorized Gaussian (FFG) convolutions with learned per-weight μ and σ, and the local reparameterization trick (Kingma et al. 2015). 
They differ in the dropout layer:

* **bwn** / **bwn_multi**: FFG conv + Bernoulli dropout
  (``bwn`` disables dropout at inference; ``bwn_multi`` keeps it on)
* **bvwn_multi_prior**: FFG conv + Concrete dropout (learned per-filter rate)
  This is the "spike-and-slab dropout" (SSD) model from the paper.

Reference
---------
McClure P. et al., "Knowing What You Know in Brain Segmentation Using
Bayesian Deep Neural Networks", Front. Neuroinform. 2019.
"""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from nobrainer.models._constants import (  # noqa: E501
    DILATION_SCHEDULES as _DILATION_SCHEDULES,
)

from .vwn_layers import ConcreteDropout3d, FFGConv3d


class _VWNLayerBernoulli(nn.Module):
    """VWN conv + ReLU + Bernoulli dropout (bwn / bwn_multi).

    Applies ``relu(dropout(conv(x)))`` with a 3x3x3 dilated, bias-free
    :class:`FFGConv3d` (padding equal to the dilation, so spatial size is
    preserved) and ``nn.Dropout3d``.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        dilation: int,
        dropout_rate: float,
        sigma_init: float,
    ) -> None:
        super().__init__()
        self.conv = FFGConv3d(
            in_ch,
            out_ch,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            bias=False,
            sigma_init=sigma_init,
        )
        self.dropout = nn.Dropout3d(p=dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        mc_vwn: bool = True,
        mc_dropout: bool = True,
    ) -> torch.Tensor:
        # Original TF order: conv -> dropout -> relu (meshnetbwn.py:59-61)
        h = self.conv(x, mc=mc_vwn)
        # NOTE(review): ``nn.Dropout3d`` is the identity in ``eval()`` mode,
        # so ``mc_dropout=True`` only produces stochastic masks while this
        # module is in training mode. Confirm MC-dropout inference callers
        # keep the dropout layers in train mode.
        if mc_dropout:
            h = self.dropout(h)
        return F.relu(h)


class _VWNLayerConcrete(nn.Module):
    """VWN conv + ReLU + Concrete dropout (bvwn_multi_prior).

    Applies ``relu(concrete_dropout(conv(x)))`` with a 3x3x3 dilated,
    bias-free :class:`FFGConv3d` and a per-filter :class:`ConcreteDropout3d`.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        dilation: int,
        sigma_init: float,
        concrete_temperature: float = 0.02,
        concrete_init_p: float = 0.9,
    ) -> None:
        super().__init__()
        self.conv = FFGConv3d(
            in_ch,
            out_ch,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            bias=False,
            sigma_init=sigma_init,
        )
        self.dropout = ConcreteDropout3d(
            out_ch,
            temperature=concrete_temperature,
            init_p=concrete_init_p,
        )

    def forward(
        self,
        x: torch.Tensor,
        mc_vwn: bool = True,
        mc_dropout: bool = True,
    ) -> torch.Tensor:
        # Original TF order: conv -> dropout -> relu (meshnetbvwn.py:54-55)
        h = self.conv(x, mc=mc_vwn)
        h = self.dropout(h, mc=mc_dropout)
        return F.relu(h)


class KWYKMeshNet(nn.Module):
    """KWYK MeshNet with variational weight normalization.

    This is the architecture used in McClure et al. (2019). All layers use
    VWN convolutions; the ``dropout_type`` parameter selects between
    Bernoulli (``"bernoulli"``) and Concrete (``"concrete"``) dropout.
    Any value other than ``"concrete"`` builds Bernoulli layers.

    Parameters
    ----------
    n_classes : int
        Number of output segmentation classes.
    in_channels : int
        Number of input image channels.
    filters : int
        Feature-map count in all hidden layers.
    receptive_field : int
        One of ``37``, ``67``, ``129`` — selects the dilation schedule.
    dropout_type : str
        ``"bernoulli"`` for bwn/bwn_multi, ``"concrete"`` for bvwn_multi_prior.
    dropout_rate : float
        For Bernoulli dropout (ignored for concrete).
    sigma_init : float
        Initial value for weight sigma (default 1e-4, matching kwyk).
    concrete_temperature : float
        Temperature for concrete dropout (default 0.02).
    concrete_init_p : float
        Initial dropout probability for concrete dropout (default 0.9).

    Raises
    ------
    ValueError
        If ``receptive_field`` is not a key of the dilation-schedule table.
    """

    def __init__(
        self,
        n_classes: int = 1,
        in_channels: int = 1,
        filters: int = 71,
        receptive_field: int = 67,
        dropout_type: str = "bernoulli",
        dropout_rate: float = 0.25,
        sigma_init: float = 1e-4,
        concrete_temperature: float = 0.02,
        concrete_init_p: float = 0.9,
    ) -> None:
        super().__init__()
        if receptive_field not in _DILATION_SCHEDULES:
            raise ValueError(
                f"receptive_field must be one of {list(_DILATION_SCHEDULES)}, "
                f"got {receptive_field}"
            )
        self.dropout_type = dropout_type
        dilations = _DILATION_SCHEDULES[receptive_field]
        self._n_layers = len(dilations)

        # Named ``layer_{i}`` attributes (not a ModuleList) — these names are
        # the state_dict keys, so keep them stable.
        for i, dil in enumerate(dilations):
            in_ch = in_channels if i == 0 else filters
            if dropout_type == "concrete":
                layer = _VWNLayerConcrete(
                    in_ch,
                    filters,
                    dil,
                    sigma_init,
                    concrete_temperature,
                    concrete_init_p,
                )
            else:
                layer = _VWNLayerBernoulli(
                    in_ch,
                    filters,
                    dil,
                    dropout_rate,
                    sigma_init,
                )
            setattr(self, f"layer_{i}", layer)

        self.classifier = nn.Conv3d(filters, n_classes, kernel_size=1)

    def forward(
        self,
        x: torch.Tensor,
        mc: bool | None = None,
        mc_vwn: bool = True,
        mc_dropout: bool = True,
    ) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : Tensor
            Input ``(B, in_channels, D, H, W)``.
        mc : bool or None
            Legacy convenience flag. If provided, sets both ``mc_vwn`` and
            ``mc_dropout`` to the same value (backward compat).
        mc_vwn : bool
            If True, use stochastic VWN reparameterization. If False,
            use deterministic mean weights only.
        mc_dropout : bool
            If True, apply stochastic dropout. If False, skip dropout
            (Bernoulli) or use expectation (Concrete).

        Note
        ----
        The original TF bwn model trains with ``mc_vwn=False, mc_dropout=True``
        (deterministic weights + stochastic dropout).
        """
        if mc is not None:
            mc_vwn = mc
            mc_dropout = mc
        h = x
        for i in range(self._n_layers):
            h = getattr(self, f"layer_{i}")(h, mc_vwn=mc_vwn, mc_dropout=mc_dropout)
        return self.classifier(h)

    def kl_divergence(self) -> torch.Tensor:
        """Sum KL divergence from all VWN conv layers.

        Each :class:`FFGConv3d` only refreshes its ``kl`` on a stochastic
        (``mc_vwn=True``) pass, so after deterministic passes this sum
        reflects the last stochastic pass (``0`` if there was none).
        """
        kl = torch.tensor(0.0, device=next(self.parameters()).device)
        for m in self.modules():
            if isinstance(m, FFGConv3d):
                kl = kl + m.kl
        return kl

    def concrete_regularization(self) -> torch.Tensor:
        """Sum concrete dropout regularization (0 for bernoulli models)."""
        # NOTE(review): ``accumulate_kl`` (bayesian/utils.py) calls
        # ``ConcreteDropout3d.kl_divergence()`` while this calls
        # ``regularization()``; confirm both methods exist and agree.
        reg = torch.tensor(0.0, device=next(self.parameters()).device)
        for m in self.modules():
            if isinstance(m, ConcreteDropout3d):
                reg = reg + m.regularization()
        return reg


def kwyk_meshnet(
    n_classes: int = 1,
    in_channels: int = 1,
    filters: int = 71,
    receptive_field: int = 67,
    dropout_type: str = "bernoulli",
    dropout_rate: float = 0.25,
    sigma_init: float = 1e-4,
    concrete_temperature: float = 0.02,
    concrete_init_p: float = 0.9,
    **kwargs,
) -> KWYKMeshNet:
    """Factory function for :class:`KWYKMeshNet`.

    Extra keyword arguments are accepted and ignored.
    """
    return KWYKMeshNet(
        n_classes=n_classes,
        in_channels=in_channels,
        filters=filters,
        receptive_field=receptive_field,
        dropout_type=dropout_type,
        dropout_rate=dropout_rate,
        sigma_init=sigma_init,
        concrete_temperature=concrete_temperature,
        concrete_init_p=concrete_init_p,
    )


__all__ = ["KWYKMeshNet", "kwyk_meshnet"]

================================================
FILE: nobrainer/models/bayesian/layers.py
================================================
"""Bayesian convolutional and linear layers as Pyro modules.

Both ``BayesianConv3d`` and ``BayesianLinear`` implement weight uncertainty by
maintaining learnable ``weight_mu`` and ``weight_rho`` parameters, with
``weight_sigma = softplus(weight_rho)``. During each stochastic forward pass
they sample a weight tensor from ``Normal(weight_mu, weight_sigma)`` and
accumulate the KL divergence against the prior into ``self.kl``.

Three prior types are supported (matching the kwyk study variants): * ``"standard_normal"`` — N(0, 1) prior, standard Bayes-by-backprop. * ``"laplace"`` — tight Normal N(0, 0.1) approximation of a Laplace prior. * ``"spike_and_slab"`` — mixture prior ``π·N(0, σ₁) + (1-π)·N(0, σ₂)`` where σ₁ (spike) is small and σ₂ (slab) is large. Each weight also learns a log-odds ``z_logit`` controlling how much mass is on the spike vs slab, implementing variational spike-and-slab dropout (SSD) as in McClure et al. (2019). """ from __future__ import annotations import math import pyro import pyro.distributions as dist from pyro.nn import PyroModule, PyroParam import torch from torch.distributions import constraints import torch.nn.functional as F # --------------------------------------------------------------------------- # KL helpers # --------------------------------------------------------------------------- def _kl_normal_normal( mu: torch.Tensor, sigma: torch.Tensor, prior_mu: float, prior_sigma: float, ) -> torch.Tensor: """Analytic KL(N(mu, sigma) || N(prior_mu, prior_sigma)).""" return ( torch.log(prior_sigma / (sigma + 1e-8)) + (sigma**2 + (mu - prior_mu) ** 2) / (2 * prior_sigma**2) - 0.5 ).sum() def _kl_spike_and_slab( mu: torch.Tensor, sigma: torch.Tensor, z_logit: torch.Tensor, spike_sigma: float, slab_sigma: float, prior_pi: float, ) -> torch.Tensor: """KL divergence for spike-and-slab variational posterior. The variational posterior is: q(w, z) = Bernoulli(z; sigmoid(z_logit)) · N(w; mu, sigma) The prior is: p(w, z) = (pi·N(0, spike_sigma) + (1-pi)·N(0, slab_sigma)) We use the closed-form approximation from Louizos et al. (2017) and the practical version used in the kwyk spike-and-slab dropout. 
""" z = torch.sigmoid(z_logit) # Log-likelihood under spike and slab components log_spike = -0.5 * math.log(2 * math.pi * spike_sigma**2) - ( mu**2 + sigma**2 ) / (2 * spike_sigma**2) log_slab = -0.5 * math.log(2 * math.pi * slab_sigma**2) - ( mu**2 + sigma**2 ) / (2 * slab_sigma**2) # Entropy of the Bernoulli gate entropy_z = -(z * torch.log(z + 1e-8) + (1 - z) * torch.log(1 - z + 1e-8)) # KL = E_q[log q - log p] # log q(w|z=slab) - log p(w) where p is the mixture kl_per_weight = ( z * (-0.5 * torch.log(2 * math.pi * sigma**2 + 1e-8) - 0.5 - log_slab) + (1 - z) * (-log_spike) - entropy_z + z * math.log(1 - prior_pi + 1e-8) + (1 - z) * math.log(prior_pi + 1e-8) ) return kl_per_weight.sum() # --------------------------------------------------------------------------- # Bayesian layers # --------------------------------------------------------------------------- class BayesianConv3d(PyroModule): """3-D convolution with learnable weight distribution (Pyro). Parameters ---------- in_channels, out_channels : int Standard convolution channel counts. kernel_size : int Cubic kernel side length. stride, padding, dilation : int Standard ``nn.Conv3d`` arguments. bias : bool Whether to include a deterministic bias term. prior_type : str ``"standard_normal"`` (σ=1), ``"laplace"`` (tight Normal σ=0.1), or ``"spike_and_slab"`` (mixture prior with learnable gates). spike_sigma : float Spike component σ for spike-and-slab prior (default 0.001). slab_sigma : float Slab component σ for spike-and-slab prior (default 1.0). prior_pi : float Prior probability of the spike component (default 0.5). 
""" def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 0, dilation: int = 1, bias: bool = True, prior_type: str = "standard_normal", spike_sigma: float = 0.001, slab_sigma: float = 1.0, prior_pi: float = 0.5, ) -> None: super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.stride = stride self.padding = padding self.dilation = dilation self.prior_type = prior_type self.spike_sigma = spike_sigma self.slab_sigma = slab_sigma self.prior_pi = prior_pi weight_shape = ( out_channels, in_channels, kernel_size, kernel_size, kernel_size, ) # Kaiming init for mu fan_in = in_channels * kernel_size**3 std_init = math.sqrt(2.0 / fan_in) self.weight_mu = PyroParam( torch.zeros(weight_shape).normal_(0, std_init), constraint=constraints.real, ) self.weight_rho = PyroParam( torch.full(weight_shape, -3.0), # softplus(-3) ≈ 0.05 constraint=constraints.real, ) # Spike-and-slab gate logits (one per weight) if prior_type == "spike_and_slab": self.z_logit = PyroParam( torch.full(weight_shape, 2.0), # sigmoid(2) ≈ 0.88 → mostly slab constraint=constraints.real, ) if bias: self.bias_mu = PyroParam( torch.zeros(out_channels), constraint=constraints.real ) self.bias_rho = PyroParam( torch.full((out_channels,), -3.0), constraint=constraints.real ) else: self.bias_mu = None self.bias_rho = None if prior_type == "standard_normal": self.prior_sigma = 1.0 elif prior_type == "laplace": self.prior_sigma = 0.1 else: self.prior_sigma = slab_sigma # used as fallback only self.kl: torch.Tensor = torch.tensor(0.0) @property def weight_sigma(self) -> torch.Tensor: return F.softplus(self.weight_rho) def forward(self, x: torch.Tensor) -> torch.Tensor: weight = pyro.sample( f"{self._pyro_name}.weight", dist.Normal(self.weight_mu, self.weight_sigma + 1e-8).to_event( self.weight_mu.dim() ), ) if self.prior_type == "spike_and_slab": # Apply spike-and-slab mask: sample Bernoulli gate, 
mask weights z_prob = torch.sigmoid(self.z_logit) z_mask = torch.bernoulli(z_prob) weight = weight * z_mask self.kl = _kl_spike_and_slab( self.weight_mu, self.weight_sigma, self.z_logit, self.spike_sigma, self.slab_sigma, self.prior_pi, ) else: self.kl = _kl_normal_normal( self.weight_mu, self.weight_sigma, 0.0, self.prior_sigma ) bias = None if self.bias_mu is not None: bias_sigma = F.softplus(self.bias_rho) bias = pyro.sample( f"{self._pyro_name}.bias", dist.Normal(self.bias_mu, bias_sigma + 1e-8).to_event(1), ) self.kl = self.kl + _kl_normal_normal( self.bias_mu, bias_sigma, 0.0, self.prior_sigma ) return F.conv3d(x, weight, bias, self.stride, self.padding, self.dilation) class BayesianLinear(PyroModule): """Fully-connected layer with learnable weight distribution (Pyro). Parameters ---------- in_features, out_features : int Standard ``nn.Linear`` dimensions. bias : bool Whether to include a deterministic bias term. prior_type : str ``"standard_normal"``, ``"laplace"``, or ``"spike_and_slab"``. spike_sigma : float Spike component σ for spike-and-slab prior (default 0.001). slab_sigma : float Slab component σ for spike-and-slab prior (default 1.0). prior_pi : float Prior probability of the spike component (default 0.5). 
""" def __init__( self, in_features: int, out_features: int, bias: bool = True, prior_type: str = "standard_normal", spike_sigma: float = 0.001, slab_sigma: float = 1.0, prior_pi: float = 0.5, ) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.prior_type = prior_type self.spike_sigma = spike_sigma self.slab_sigma = slab_sigma self.prior_pi = prior_pi std_init = math.sqrt(2.0 / in_features) self.weight_mu = PyroParam( torch.zeros(out_features, in_features).normal_(0, std_init), constraint=constraints.real, ) self.weight_rho = PyroParam( torch.full((out_features, in_features), -3.0), constraint=constraints.real, ) if prior_type == "spike_and_slab": self.z_logit = PyroParam( torch.full((out_features, in_features), 2.0), constraint=constraints.real, ) if bias: self.bias_mu = PyroParam( torch.zeros(out_features), constraint=constraints.real ) self.bias_rho = PyroParam( torch.full((out_features,), -3.0), constraint=constraints.real ) else: self.bias_mu = None self.bias_rho = None if prior_type == "standard_normal": self.prior_sigma = 1.0 elif prior_type == "laplace": self.prior_sigma = 0.1 else: self.prior_sigma = slab_sigma self.kl: torch.Tensor = torch.tensor(0.0) @property def weight_sigma(self) -> torch.Tensor: return F.softplus(self.weight_rho) def forward(self, x: torch.Tensor) -> torch.Tensor: weight = pyro.sample( f"{self._pyro_name}.weight", dist.Normal(self.weight_mu, self.weight_sigma + 1e-8).to_event(2), ) if self.prior_type == "spike_and_slab": z_prob = torch.sigmoid(self.z_logit) z_mask = torch.bernoulli(z_prob) weight = weight * z_mask self.kl = _kl_spike_and_slab( self.weight_mu, self.weight_sigma, self.z_logit, self.spike_sigma, self.slab_sigma, self.prior_pi, ) else: self.kl = _kl_normal_normal( self.weight_mu, self.weight_sigma, 0.0, self.prior_sigma ) bias = None if self.bias_mu is not None: bias_sigma = F.softplus(self.bias_rho) bias = pyro.sample( f"{self._pyro_name}.bias", dist.Normal(self.bias_mu, 
bias_sigma + 1e-8).to_event(1), ) self.kl = self.kl + _kl_normal_normal( self.bias_mu, bias_sigma, 0.0, self.prior_sigma ) return F.linear(x, weight, bias) ================================================ FILE: nobrainer/models/bayesian/utils.py ================================================ """Utility functions for Bayesian models.""" from __future__ import annotations import torch from .layers import BayesianConv3d, BayesianLinear from .vwn_layers import ConcreteDropout3d, FFGConv3d def accumulate_kl(model: torch.nn.Module) -> torch.Tensor: """Sum KL divergence from all Bayesian layers in ``model``. Works with both Pyro-based models (BayesianConv3d, BayesianLinear) and VWN/FFG models (FFGConv3d, ConcreteDropout3d). Parameters ---------- model : nn.Module A model containing one or more Bayesian layers. Returns ------- torch.Tensor Scalar KL sum. """ kl = torch.tensor(0.0) for m in model.modules(): # Pyro-based layers if isinstance(m, (BayesianConv3d, BayesianLinear)): kl = kl + m.kl # VWN/FFG layers elif isinstance(m, FFGConv3d): kl = kl + m.kl # Concrete dropout regularization elif isinstance(m, ConcreteDropout3d): kl = kl + m.kl_divergence() return kl ================================================ FILE: nobrainer/models/bayesian/vwn_layers.py ================================================ """Fully Factorized Gaussian (FFG) layers with local reparameterization. These layers implement the convolution used in McClure et al. (2019), Section 2.2.3.2 ("Spike-and-Slab Dropout with Learned Model Uncertainty"): * Each weight has learnable mean ``μ_{f,t}`` and std ``σ_{f,t}`` * **Local reparameterization trick** (Kingma et al. 2015): instead of sampling weights, the output distribution is computed directly: ``output ~ N(conv(x, μ), conv(x², σ²))`` (Eqs. 12-14) * The **spike-and-slab dropout (SSD)** model combines this with **concrete dropout** (Gal et al. 2017): ``output_v = b_f · (g_f * h)_v`` where ``b_f`` is a per-filter concrete dropout mask (Eq. 11). 
The KL divergence has two terms (Eq. 16 in paper):

1. Bernoulli KL for concrete dropout gates (Eq. 17)
2. Gaussian KL for each weight: ``KL(N(μ,σ) || N(μ_prior, σ_prior))`` (Eq. 18)

Prior parameters from the paper: ``p_prior=0.5, μ_prior=0, σ_prior=0.1``

Two dropout variants:

* **Bernoulli dropout** — standard ``nn.Dropout3d``, fixed rate (BD model)
* **Concrete dropout** — per-filter learnable drop rate (SSD model)
"""

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class FFGConv3d(nn.Module):
    """3-D convolution with Variational Weight Normalization + learned sigma.

    Verified against the actual kwyk trained model (``neuronets/kwyk:latest``).
    Each layer stores: ``v``, ``g``, ``kernel_a``, ``bias_m``, ``bias_a``.

    **Mean weights** use weight normalization (Salimans & Kingma 2016):
        ``kernel_m = g · v / ||v||``

    **Sigma** is learned per weight:
        ``kernel_sigma = |kernel_a|``

    During stochastic forward passes (``mc=True``), the **local
    reparameterization trick** (Kingma et al. 2015, Eqs. 12-14 in
    McClure et al. 2019) computes the output distribution directly:

        ``μ* = conv(x, kernel_m)``
        ``σ*² = conv(x², kernel_sigma²)``
        ``output = μ* + σ* · ε,  ε ~ N(0, 1)``

    In deterministic mode (``mc=False``) only the mean path is used and
    ``self.kl`` is not recomputed.

    Parameters
    ----------
    in_channels, out_channels : int
        Standard convolution channel counts.
    kernel_size : int
        Cubic kernel side length.
    stride, padding, dilation : int
        Standard ``nn.Conv3d`` arguments.
    bias : bool
        Whether to include a bias term (with its own sigma).
    sigma_init : float
        Initial value for ``|kernel_a|`` (default 1e-4, matching kwyk).
    prior_mu : float
        Prior mean for KL (default 0.0).
    prior_sigma : float
        Prior std for KL (default 0.1, matching paper Eq. 18).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        bias: bool = True,
        sigma_init: float = 1e-4,
        prior_mu: float = 0.0,
        prior_sigma: float = 0.1,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma

        k = kernel_size
        weight_shape = (out_channels, in_channels, k, k, k)
        # One gain per output filter, broadcast over the kernel.
        g_shape = (out_channels, 1, 1, 1, 1)

        # Weight normalization: kernel_m = g * v / ||v||
        self.v = nn.Parameter(torch.empty(weight_shape))
        nn.init.kaiming_normal_(self.v, mode="fan_in", nonlinearity="relu")
        self.g = nn.Parameter(torch.full(g_shape, math.sqrt(2.0)))

        # Learned sigma: kernel_sigma = |kernel_a|
        self.kernel_a = nn.Parameter(torch.full(weight_shape, sigma_init))

        # Parameter names must stay as-is: they match the kwyk checkpoints.
        if bias:
            self.bias_m = nn.Parameter(torch.zeros(out_channels))
            self.bias_a = nn.Parameter(torch.full((out_channels,), sigma_init))
        else:
            self.register_parameter("bias_m", None)
            self.register_parameter("bias_a", None)

        # Accumulated KL (updated each forward pass)
        # Only refreshed by stochastic (mc=True) passes; see ``forward``.
        self.kl: torch.Tensor = torch.tensor(0.0)

    @property
    def kernel_m(self) -> torch.Tensor:
        """Mean weight: ``g · v / ||v||``."""
        # The norm is taken per output filter (over in_ch × k × k × k).
        v_norm = F.normalize(self.v.flatten(1), dim=1).view_as(self.v)
        return self.g * v_norm

    @property
    def weight_sigma(self) -> torch.Tensor:
        """Weight std: ``|kernel_a|``."""
        return torch.abs(self.kernel_a)

    def forward(self, x: torch.Tensor, mc: bool = True) -> torch.Tensor:
        """Forward pass with optional stochastic sampling.

        With ``mc=False`` the deterministic mean output
        ``conv(x, kernel_m) + bias_m`` is returned and ``self.kl`` keeps the
        value from the last ``mc=True`` call (initially ``0``). With
        ``mc=True`` the output is sampled via local reparameterization and
        ``self.kl`` is recomputed for the kernel weights.
        """
        km = self.kernel_m
        out_mean = F.conv3d(
            x, km, self.bias_m, self.stride, self.padding, self.dilation
        )
        if not mc:
            return out_mean

        # Local reparameterization trick (Eqs. 12-14)
        sigma = self.weight_sigma
        out_var = F.conv3d(
            x.pow(2), sigma.pow(2), None, self.stride, self.padding, self.dilation
        )
        if self.bias_a is not None:
            bias_sigma = torch.abs(self.bias_a)
            # Broadcast the per-channel bias variance over (B, C, D, H, W).
            out_var = out_var + bias_sigma.pow(2).view(1, -1, 1, 1, 1)

        noise = torch.randn_like(out_mean)
        # The epsilon keeps sqrt (and its gradient) finite where the variance
        # is exactly zero.
        out = out_mean + torch.sqrt(out_var + 1e-8) * noise

        # KL(N(kernel_m, sigma) || N(prior_mu, prior_sigma)) — Eq. 18
        # NOTE(review): only the kernel contributes; bias_m / bias_a are not
        # included in the KL. Confirm this matches the kwyk reference.
        self.kl = (
            torch.log(self.prior_sigma / (sigma + 1e-8))
            + (sigma.pow(2) + (km - self.prior_mu).pow(2))
            / (2 * self.prior_sigma**2)
            - 0.5
        ).sum()

        return out


# Backward-compatible alias
VWNConv3d = FFGConv3d


class ConcreteDropout3d(nn.Module): """Concrete dropout (Gal et al. 2017) with per-filter learnable rate. Instead of a fixed dropout probability, each output filter learns its own drop rate ``p`` via a continuous relaxation of Bernoulli sampling (Eq. 10 in McClure et al. 2019). Parameters ---------- n_filters : int Number of filters (one ``p`` per filter). temperature : float Concrete distribution temperature (default 0.02, matching paper). init_p : float Initial dropout probability (default 0.9, matching kwyk code). prior_p : float Prior dropout probability for KL (default 0.5, matching paper). """ def __init__( self, n_filters: int, temperature: float = 0.02, init_p: float = 0.9, prior_p: float = 0.5, ) -> None: super().__init__() self.temperature = temperature self.prior_p = prior_p # Store as raw logit; p = sigmoid(p_logit) to keep in (0, 1) init_logit = math.log(init_p / (1 - init_p + 1e-8)) self.p_logit = nn.Parameter(torch.full((n_filters,), init_logit)) @property def p(self) -> torch.Tensor: """Per-filter dropout probabilities, clamped to [0.05, 0.95].""" return torch.sigmoid(self.p_logit).clamp(0.05, 0.95) def forward(self, x: torch.Tensor, mc: bool = True) -> torch.Tensor: """Apply concrete dropout (Eq. 10). Parameters ---------- x : Tensor Input ``(B, C, D, H, W)``. mc : bool If True, sample from concrete distribution. If False, scale by ``p`` (expectation).
""" p = self.p.view(1, -1, 1, 1, 1) if not mc: return x * p # Concrete relaxation of Bernoulli (Eq. 10) eps = 1e-8 noise = torch.rand_like(x[:1]) # (1, C, D, H, W) z = torch.sigmoid( ( torch.log(p + eps) - torch.log(1 - p + eps) + torch.log(noise + eps) - torch.log(1 - noise + eps) ) / self.temperature ) return x * z def kl_divergence(self) -> torch.Tensor: """KL(q_p || p_prior) for Bernoulli distributions (Eq. 17).""" p = self.p pp = self.prior_p eps = 1e-8 return ( p * torch.log(p / (pp + eps) + eps) + (1 - p) * torch.log((1 - p) / (1 - pp + eps) + eps) ).sum() def regularization(self) -> torch.Tensor: """Alias for kl_divergence (backward compat).""" return self.kl_divergence() ================================================ FILE: nobrainer/models/bayesian/warmstart.py ================================================ """Warm-start a Bayesian model from a trained deterministic model.""" from __future__ import annotations import logging from pathlib import Path import torch import torch.nn as nn from nobrainer.models.bayesian.layers import BayesianConv3d from nobrainer.models.bayesian.vwn_layers import FFGConv3d logger = logging.getLogger(__name__) def warmstart_bayesian_from_deterministic( bayesian_model: nn.Module, deterministic_model: nn.Module, initial_rho: float = -3.0, ) -> int: """Transfer deterministic Conv3d weights to BayesianConv3d weight_mu. Matches layers by position (not name) since the deterministic MeshNet uses ``nn.Sequential`` (``encoder.N.block.0``) while the Bayesian MeshNet uses named attributes (``layer_N.conv``). For each matching pair: * **Conv3d -> BayesianConv3d**: copies ``weight`` to ``weight_mu``, fills ``weight_rho`` with *initial_rho*, and handles bias if present. * **BatchNorm3d -> BatchNorm3d**: copies ``weight``, ``bias``, ``running_mean``, and ``running_var``. Parameters ---------- bayesian_model : nn.Module Target Bayesian model whose parameters will be overwritten. 
deterministic_model : nn.Module Source deterministic model with trained weights. initial_rho : float, optional Value to fill ``weight_rho`` (and ``bias_rho``) with. ``softplus(-3.0) ≈ 0.05``. Default is ``-3.0``. Returns ------- int Number of layers whose weights were transferred. """ # First try name-based matching (works if architectures share naming) transferred = _transfer_by_name(bayesian_model, deterministic_model, initial_rho) if transferred > 0: return transferred # Fall back to positional matching (different naming conventions) return _transfer_by_position(bayesian_model, deterministic_model, initial_rho) def _transfer_by_name( bayesian_model: nn.Module, deterministic_model: nn.Module, initial_rho: float, ) -> int: """Match layers by module name.""" det_modules = dict(deterministic_model.named_modules()) bayes_modules = dict(bayesian_model.named_modules()) transferred = 0 for name, bayes_mod in bayes_modules.items(): if name not in det_modules: continue det_mod = det_modules[name] transferred += _transfer_pair(det_mod, bayes_mod, name, initial_rho) if transferred > 0: logger.info( "Warm-started %d layers (name-matched) from deterministic model.", transferred, ) return transferred def _transfer_by_position( bayesian_model: nn.Module, deterministic_model: nn.Module, initial_rho: float, ) -> int: """Match Conv3d/BN layers by position (order of appearance).""" # Collect Conv3d layers from deterministic model det_convs = [ (n, m) for n, m in deterministic_model.named_modules() if isinstance(m, nn.Conv3d) ] det_bns = [ (n, m) for n, m in deterministic_model.named_modules() if isinstance(m, nn.BatchNorm3d) ] # Collect BayesianConv3d layers from Bayesian model bayes_convs = [ (n, m) for n, m in bayesian_model.named_modules() if isinstance(m, BayesianConv3d) ] bayes_bns = [ (n, m) for n, m in bayesian_model.named_modules() if isinstance(m, nn.BatchNorm3d) ] transferred = 0 # Transfer Conv3d -> BayesianConv3d by position for i, ((det_name, det_conv), (bay_name, 
bay_conv)) in enumerate( zip(det_convs, bayes_convs) ): if det_conv.weight.shape != bay_conv.weight_mu.shape: logger.warning( "Shape mismatch at position %d: det %s %s vs bay %s %s", i, det_name, det_conv.weight.shape, bay_name, bay_conv.weight_mu.shape, ) continue bay_conv.weight_mu.data.copy_(det_conv.weight.data) bay_conv.weight_rho.data.fill_(initial_rho) if det_conv.bias is not None and bay_conv.bias_mu is not None: bay_conv.bias_mu.data.copy_(det_conv.bias.data) bay_conv.bias_rho.data.fill_(initial_rho) transferred += 1 logger.debug("Transferred Conv3d[%d] %s -> %s", i, det_name, bay_name) # Transfer BatchNorm3d by position for i, ((det_name, det_bn), (bay_name, bay_bn)) in enumerate( zip(det_bns, bayes_bns) ): if det_bn.weight is not None and bay_bn.weight is not None: bay_bn.weight.data.copy_(det_bn.weight.data) if det_bn.bias is not None and bay_bn.bias is not None: bay_bn.bias.data.copy_(det_bn.bias.data) if det_bn.running_mean is not None: bay_bn.running_mean.copy_(det_bn.running_mean) if det_bn.running_var is not None: bay_bn.running_var.copy_(det_bn.running_var) transferred += 1 logger.debug("Transferred BatchNorm3d[%d] %s -> %s", i, det_name, bay_name) logger.info( "Warm-started %d layers (position-matched) from deterministic model.", transferred, ) return transferred def _transfer_pair( det_mod: nn.Module, bayes_mod: nn.Module, name: str, initial_rho: float, ) -> int: """Transfer weights for a single matching pair. 
Returns 1 if transferred.""" is_conv = isinstance(det_mod, nn.Conv3d) is_bayes_conv = isinstance(bayes_mod, BayesianConv3d) if is_conv and is_bayes_conv: bayes_mod.weight_mu.data.copy_(det_mod.weight.data) bayes_mod.weight_rho.data.fill_(initial_rho) if det_mod.bias is not None and bayes_mod.bias_mu is not None: bayes_mod.bias_mu.data.copy_(det_mod.bias.data) bayes_mod.bias_rho.data.fill_(initial_rho) logger.debug("Transferred Conv3d weights: %s", name) return 1 if isinstance(det_mod, nn.BatchNorm3d) and isinstance(bayes_mod, nn.BatchNorm3d): if det_mod.weight is not None and bayes_mod.weight is not None: bayes_mod.weight.data.copy_(det_mod.weight.data) if det_mod.bias is not None and bayes_mod.bias is not None: bayes_mod.bias.data.copy_(det_mod.bias.data) if det_mod.running_mean is not None: bayes_mod.running_mean.copy_(det_mod.running_mean) if det_mod.running_var is not None: bayes_mod.running_var.copy_(det_mod.running_var) logger.debug("Transferred BatchNorm3d params: %s", name) return 1 return 0 # --------------------------------------------------------------------------- # KWYK MeshNet warm-start (VWN-based, no Pyro) # --------------------------------------------------------------------------- def warmstart_kwyk_from_deterministic( kwyk_model: nn.Module, det_weights_path: str | Path, get_model_fn=None, ) -> int: """Transfer deterministic MeshNet weights to a KWYKMeshNet. For each VWN conv layer, the deterministic weight ``w`` is decomposed into weight normalization form: ``v = w``, ``g = ||w||`` per filter. The sigma parameters (``kernel_a``) are left at their initial values. Parameters ---------- kwyk_model : nn.Module Target KWYKMeshNet. det_weights_path : str or Path Path to a deterministic MeshNet ``model.pth``. get_model_fn : callable, optional Model factory (``nobrainer.models.get``). If None, imported lazily. Returns ------- int Number of layers transferred. 
""" if get_model_fn is None: from nobrainer.models import get as get_model_fn det_weights_path = Path(det_weights_path) state = torch.load(det_weights_path, weights_only=True) # Separate encoder conv weights from classifier — sorted() puts # "classifier" before "encoder" alphabetically, so we must filter # to avoid misaligning the layer pairing. encoder_convs = [] classifier_w = None classifier_b = None for k in sorted(state.keys()): v = state[k] if k == "classifier.weight" and v.ndim == 5: classifier_w = v elif k == "classifier.bias": classifier_b = v elif "weight" in k and v.ndim == 5: encoder_convs.append((k, v)) # Collect FFGConv3d layers from the kwyk model kwyk_convs = [ (n, m) for n, m in kwyk_model.named_modules() if isinstance(m, FFGConv3d) ] transferred = 0 for (det_name, det_w), (kwyk_name, kwyk_conv) in zip(encoder_convs, kwyk_convs): if det_w.shape != kwyk_conv.v.shape: logger.warning( "Shape mismatch: %s %s vs %s %s", det_name, det_w.shape, kwyk_name, kwyk_conv.v.shape, ) continue # Decompose w into weight-norm form: v = w, g = ||w|| per filter kwyk_conv.v.data.copy_(det_w) # g = ||v|| per output filter (over in_channels * k * k * k) norms = det_w.flatten(1).norm(dim=1).view_as(kwyk_conv.g) kwyk_conv.g.data.copy_(norms) transferred += 1 logger.debug("Transferred Conv3d %s -> %s", det_name, kwyk_name) # Transfer classifier separately (regular Conv3d, not FFGConv3d) if classifier_w is not None and hasattr(kwyk_model, "classifier"): kwyk_model.classifier.weight.data.copy_(classifier_w) if classifier_b is not None and kwyk_model.classifier.bias is not None: kwyk_model.classifier.bias.data.copy_(classifier_b) transferred += 1 logger.debug("Transferred classifier") logger.info("Warm-started %d layers from deterministic model.", transferred) return transferred ================================================ FILE: nobrainer/models/generative/__init__.py ================================================ """Generative model sub-package (Phase 5 — US3).""" from 
.dcgan import DCGAN, dcgan from .progressivegan import ProgressiveGAN, progressivegan __all__ = [ "DCGAN", "ProgressiveGAN", "dcgan", "progressivegan", ] ================================================ FILE: nobrainer/models/generative/dcgan.py ================================================ """DCGAN implemented as a PyTorch Lightning module. Standard alternating generator/discriminator training using BCE loss. Reference --------- Radford A. et al., "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks", ICLR 2016. arXiv:1511.06434. """ from __future__ import annotations from typing import Any import pytorch_lightning as pl import torch import torch.nn as nn import torch.nn.functional as F # --------------------------------------------------------------------------- # Building blocks # --------------------------------------------------------------------------- class _GenBlock(nn.Module): """Transposed-conv + BN + ReLU.""" def __init__( self, in_ch: int, out_ch: int, kernel_size: int = 4, stride: int = 2, padding: int = 1, ) -> None: super().__init__() self.block = nn.Sequential( nn.ConvTranspose3d( in_ch, out_ch, kernel_size, stride=stride, padding=padding ), nn.BatchNorm3d(out_ch), nn.ReLU(inplace=True), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.block(x) class _DiscBlock(nn.Module): """Conv + (optional BN) + LeakyReLU.""" def __init__( self, in_ch: int, out_ch: int, kernel_size: int = 4, stride: int = 2, padding: int = 1, use_bn: bool = True, ) -> None: super().__init__() layers: list[nn.Module] = [ nn.Conv3d(in_ch, out_ch, kernel_size, stride=stride, padding=padding), ] if use_bn: layers.append(nn.BatchNorm3d(out_ch)) layers.append(nn.LeakyReLU(0.2, inplace=True)) self.block = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.block(x) # --------------------------------------------------------------------------- # Generator # 
--------------------------------------------------------------------------- class _DCGenerator(nn.Module): """4-level transposed-conv generator; outputs (N, 1, 32, 32, 32).""" def __init__(self, latent_size: int = 128, n_filters: int = 64) -> None: super().__init__() nf = n_filters self.net = nn.Sequential( # latent (N, Z, 1, 1, 1) → (N, nf*8, 4, 4, 4) nn.ConvTranspose3d(latent_size, nf * 8, kernel_size=4, stride=1, padding=0), nn.BatchNorm3d(nf * 8), nn.ReLU(inplace=True), # (N, nf*8, 4, 4, 4) → (N, nf*4, 8, 8, 8) _GenBlock(nf * 8, nf * 4), # (N, nf*4, 8, 8, 8) → (N, nf*2, 16, 16, 16) _GenBlock(nf * 4, nf * 2), # (N, nf*2, 16, 16, 16) → (N, nf, 32, 32, 32) _GenBlock(nf * 2, nf), # (N, nf, 32, 32, 32) → (N, 1, 32, 32, 32) nn.ConvTranspose3d(nf, 1, kernel_size=4, stride=2, padding=1), nn.Tanh(), ) def forward(self, z: torch.Tensor) -> torch.Tensor: return self.net(z.view(*z.shape, 1, 1, 1)) # --------------------------------------------------------------------------- # Discriminator # --------------------------------------------------------------------------- class _DCDiscriminator(nn.Module): """4-level conv discriminator; expects (N, 1, 64, 64, 64).""" def __init__(self, n_filters: int = 64) -> None: super().__init__() nf = n_filters self.net = nn.Sequential( # (N, 1, 64, 64, 64) → (N, nf, 32, 32, 32); no BN on first layer _DiscBlock(1, nf, use_bn=False), # → (N, nf*2, 16, 16, 16) _DiscBlock(nf, nf * 2), # → (N, nf*4, 8, 8, 8) _DiscBlock(nf * 2, nf * 4), # → (N, nf*8, 4, 4, 4) _DiscBlock(nf * 4, nf * 8), # → (N, 1, 1, 1, 1) nn.Conv3d(nf * 8, 1, kernel_size=4, stride=1, padding=0), nn.Flatten(), ) def forward(self, img: torch.Tensor) -> torch.Tensor: return self.net(img) # --------------------------------------------------------------------------- # Lightning module # --------------------------------------------------------------------------- class DCGAN(pl.LightningModule): """DCGAN as a PyTorch Lightning module. 
Uses binary cross-entropy (non-saturating G loss) with standard alternating G/D updates. Parameters ---------- latent_size : int Dimension of the latent noise vector. n_filters : int Base channel count for generator and discriminator. lr : float Learning rate for Adam. beta1 : float Adam beta1. """ def __init__( self, latent_size: int = 128, n_filters: int = 64, lr: float = 2e-4, beta1: float = 0.5, ) -> None: super().__init__() self.save_hyperparameters() self.latent_size = latent_size self.lr = lr self.beta1 = beta1 self.generator = _DCGenerator(latent_size, n_filters) self.discriminator = _DCDiscriminator(n_filters) self.automatic_optimization = False # Fixed noise for visualisation self._fixed_z = None def _sample_z(self, n: int) -> torch.Tensor: return torch.randn(n, self.latent_size, device=self.device) def training_step(self, batch: Any, batch_idx: int) -> None: opt_g, opt_d = self.optimizers() real = batch["image"] if isinstance(batch, dict) else batch[0] b = real.size(0) real_label = torch.ones(b, 1, device=self.device) fake_label = torch.zeros(b, 1, device=self.device) # --- Discriminator step --- opt_d.zero_grad() z = self._sample_z(b) fake = self.generator(z).detach() # Resize real to discriminator input size if necessary if real.shape[-1] != 64: real_in = F.interpolate( real, size=(64, 64, 64), mode="trilinear", align_corners=False ) else: real_in = real if fake.shape[-1] != 64: fake_in = F.interpolate( fake, size=(64, 64, 64), mode="trilinear", align_corners=False ) else: fake_in = fake d_real = F.binary_cross_entropy_with_logits( self.discriminator(real_in), real_label ) d_fake = F.binary_cross_entropy_with_logits( self.discriminator(fake_in), fake_label ) d_loss = (d_real + d_fake) * 0.5 self.manual_backward(d_loss) opt_d.step() # --- Generator step --- opt_g.zero_grad() z = self._sample_z(b) fake = self.generator(z) if fake.shape[-1] != 64: fake_in = F.interpolate( fake, size=(64, 64, 64), mode="trilinear", align_corners=False ) else: fake_in = 
fake g_loss = F.binary_cross_entropy_with_logits( self.discriminator(fake_in), real_label ) self.manual_backward(g_loss) opt_g.step() self.log_dict({"g_loss": g_loss, "d_loss": d_loss}, prog_bar=True) def configure_optimizers(self): opt_g = torch.optim.Adam( self.generator.parameters(), lr=self.lr, betas=(self.beta1, 0.999) ) opt_d = torch.optim.Adam( self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, 0.999) ) return [opt_g, opt_d] def dcgan( latent_size: int = 128, n_filters: int = 64, **kwargs, ) -> DCGAN: """Factory function for :class:`DCGAN`.""" return DCGAN(latent_size=latent_size, n_filters=n_filters, **kwargs) __all__ = ["DCGAN", "dcgan"] ================================================ FILE: nobrainer/models/generative/progressivegan.py ================================================ """ProgressiveGAN implemented as a PyTorch Lightning module. Grows the generator and discriminator from 4³ to the target resolution in stages. Each stage fades in a new layer using a learnable ``alpha`` parameter that rises from 0 to 1 during the fade-in phase. Reference --------- Karras T. et al., "Progressive Growing of GANs for Improved Quality, Stability, and Variation", ICLR 2018. arXiv:1710.10196. 
""" from __future__ import annotations from typing import Any import pytorch_lightning as pl import torch import torch.nn as nn import torch.nn.functional as F # --------------------------------------------------------------------------- # Building blocks # --------------------------------------------------------------------------- def _pixel_norm(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: """Pixel-wise feature vector normalisation (ProGAN style).""" return x / (x.pow(2).mean(dim=1, keepdim=True) + eps).sqrt() class _ConvBlock(nn.Module): def __init__(self, in_ch: int, out_ch: int, use_pixel_norm: bool = True) -> None: super().__init__() self.conv = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1) self.use_pixel_norm = use_pixel_norm def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.leaky_relu(self.conv(x), 0.2) if self.use_pixel_norm: x = _pixel_norm(x) return x class _ToRGB(nn.Module): def __init__(self, in_ch: int) -> None: super().__init__() self.conv = nn.Conv3d(in_ch, 1, kernel_size=1) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.conv(x) class _FromRGB(nn.Module): def __init__(self, out_ch: int) -> None: super().__init__() self.conv = nn.Conv3d(1, out_ch, kernel_size=1) def forward(self, x: torch.Tensor) -> torch.Tensor: return F.leaky_relu(self.conv(x), 0.2) # --------------------------------------------------------------------------- # Generator # --------------------------------------------------------------------------- class _Generator(nn.Module): """Progressive generator. 
Each stage doubles the spatial resolution.""" def __init__( self, latent_size: int, fmap_base: int, fmap_max: int, resolution_schedule: list[int], ) -> None: super().__init__() self.resolution_schedule = resolution_schedule self.current_level = 0 def nf(level: int) -> int: return min(int(fmap_base / (2**level)), fmap_max) # Level 0: latent → 4³ feature map self.init_block = nn.Sequential( nn.ConvTranspose3d(latent_size, nf(0), kernel_size=4, stride=1, padding=0), nn.LeakyReLU(0.2), _ConvBlock(nf(0), nf(0)), ) self.to_rgb_blocks = nn.ModuleList([_ToRGB(nf(0))]) self.upsample_blocks = nn.ModuleList() for level in range(1, len(resolution_schedule)): block = nn.Sequential( _ConvBlock(nf(level - 1), nf(level)), _ConvBlock(nf(level), nf(level)), ) self.upsample_blocks.append(block) self.to_rgb_blocks.append(_ToRGB(nf(level))) self.alpha: float = 1.0 def forward(self, z: torch.Tensor) -> torch.Tensor: x = self.init_block(z.view(*z.shape, 1, 1, 1)) if self.current_level == 0: return torch.tanh(self.to_rgb_blocks[0](x)) # Grow through levels up to current_level - 1, then fade in last level for i in range(self.current_level - 1): x = F.interpolate(x, scale_factor=2, mode="trilinear", align_corners=False) x = self.upsample_blocks[i](x) # Fade-in: blend previous RGB with new upsampled RGB prev_rgb = self.to_rgb_blocks[self.current_level - 1](x) prev_rgb = F.interpolate( prev_rgb, scale_factor=2, mode="trilinear", align_corners=False ) x = F.interpolate(x, scale_factor=2, mode="trilinear", align_corners=False) x = self.upsample_blocks[self.current_level - 1](x) new_rgb = self.to_rgb_blocks[self.current_level](x) out = self.alpha * new_rgb + (1.0 - self.alpha) * prev_rgb return torch.tanh(out) # --------------------------------------------------------------------------- # Discriminator # --------------------------------------------------------------------------- class _Discriminator(nn.Module): """Progressive discriminator. 
Mirror of the generator.""" def __init__( self, fmap_base: int, fmap_max: int, resolution_schedule: list[int], ) -> None: super().__init__() self.resolution_schedule = resolution_schedule self.current_level = 0 def nf(level: int) -> int: return min(int(fmap_base / (2**level)), fmap_max) # Level 0 (4³): feature → 1 (real/fake) self.final_block = nn.Sequential( _ConvBlock(nf(0), nf(0), use_pixel_norm=False), nn.AdaptiveAvgPool3d(1), nn.Flatten(), nn.Linear(nf(0), 1), ) self.from_rgb_blocks = nn.ModuleList([_FromRGB(nf(0))]) self.downsample_blocks = nn.ModuleList() for level in range(1, len(resolution_schedule)): block = nn.Sequential( _ConvBlock(nf(level), nf(level), use_pixel_norm=False), _ConvBlock(nf(level), nf(level - 1), use_pixel_norm=False), ) self.downsample_blocks.append(block) self.from_rgb_blocks.append(_FromRGB(nf(level))) self.alpha: float = 1.0 def forward(self, img: torch.Tensor) -> torch.Tensor: if self.current_level == 0: x = self.from_rgb_blocks[0](img) return self.final_block(x) # Fade-in: blend downsampled previous level with new level prev_img = F.avg_pool3d(img, kernel_size=2, stride=2) prev_x = self.from_rgb_blocks[self.current_level - 1](prev_img) x = self.from_rgb_blocks[self.current_level](img) x = self.downsample_blocks[self.current_level - 1](x) x = F.avg_pool3d(x, kernel_size=2, stride=2) x = self.alpha * x + (1.0 - self.alpha) * prev_x for i in range(self.current_level - 2, -1, -1): x = self.downsample_blocks[i](x) x = F.avg_pool3d(x, kernel_size=2, stride=2) return self.final_block(x) # --------------------------------------------------------------------------- # Lightning module # --------------------------------------------------------------------------- class ProgressiveGAN(pl.LightningModule): """ProgressiveGAN as a PyTorch Lightning module. Parameters ---------- latent_size : int Dimension of the latent noise vector. label_size : int Conditioning label dimension (0 = unconditional). 
fmap_base : int Base feature-map count used to compute per-level channels. fmap_max : int Maximum feature-map count at any level. resolution_schedule : list[int] Spatial resolutions to train (e.g. ``[4, 8, 16, 32]``). steps_per_phase : int Number of training steps in each fade-in phase. lambda_gp : float WGAN-GP gradient penalty weight. lr : float Learning rate for Adam (used for both G and D). """ def __init__( self, latent_size: int = 512, label_size: int = 0, fmap_base: int = 2048, fmap_max: int = 512, resolution_schedule: list[int] | None = None, steps_per_phase: int = 1000, lambda_gp: float = 10.0, lr: float = 1e-3, ) -> None: super().__init__() self.save_hyperparameters() if resolution_schedule is None: resolution_schedule = [4, 8, 16, 32, 64] self.latent_size = latent_size self.resolution_schedule = resolution_schedule self.steps_per_phase = steps_per_phase self.lambda_gp = lambda_gp self.lr = lr self.generator = _Generator( latent_size, fmap_base, fmap_max, resolution_schedule ) self.discriminator = _Discriminator(fmap_base, fmap_max, resolution_schedule) self._step_count = 0 self.automatic_optimization = False # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ def _gradient_penalty(self, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor: """Compute WGAN-GP gradient penalty.""" b = real.size(0) eps = torch.rand(b, 1, 1, 1, 1, device=real.device) interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True) d_interp = self.discriminator(interp) grads = torch.autograd.grad( outputs=d_interp, inputs=interp, grad_outputs=torch.ones_like(d_interp), create_graph=True, retain_graph=True, )[0] gp = ((grads.norm(2, dim=[1, 2, 3, 4]) - 1) ** 2).mean() return gp def _sample_z(self, n: int) -> torch.Tensor: return torch.randn(n, self.latent_size, device=self.device) # ------------------------------------------------------------------ # Training # 
------------------------------------------------------------------ def training_step(self, batch: Any, batch_idx: int) -> None: opt_g, opt_d = self.optimizers() real = batch["image"] if isinstance(batch, dict) else batch[0] b = real.size(0) z = self._sample_z(b) # --- Discriminator step --- opt_d.zero_grad() fake = self.generator(z).detach() d_real = self.discriminator(real) d_fake = self.discriminator(fake) gp = self._gradient_penalty(real, fake.requires_grad_(True)) d_loss = d_fake.mean() - d_real.mean() + self.lambda_gp * gp self.manual_backward(d_loss) opt_d.step() # --- Generator step --- opt_g.zero_grad() fake = self.generator(z) g_loss = -self.discriminator(fake).mean() self.manual_backward(g_loss) opt_g.step() self.log_dict({"g_loss": g_loss, "d_loss": d_loss}, prog_bar=True) self._step_count += 1 def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None: """Update alpha for fade-in scheduling.""" n_levels = len(self.resolution_schedule) level = min(self._step_count // self.steps_per_phase, n_levels - 1) phase_step = self._step_count % self.steps_per_phase alpha = min(phase_step / max(self.steps_per_phase, 1), 1.0) self.generator.current_level = level self.discriminator.current_level = level self.generator.alpha = alpha self.discriminator.alpha = alpha def configure_optimizers(self): opt_g = torch.optim.Adam( self.generator.parameters(), lr=self.lr, betas=(0.0, 0.99) ) opt_d = torch.optim.Adam( self.discriminator.parameters(), lr=self.lr, betas=(0.0, 0.99) ) return [opt_g, opt_d] def progressivegan( latent_size: int = 512, label_size: int = 0, fmap_base: int = 2048, fmap_max: int = 512, resolution_schedule: list[int] | None = None, **kwargs, ) -> ProgressiveGAN: """Factory function for :class:`ProgressiveGAN`.""" return ProgressiveGAN( latent_size=latent_size, label_size=label_size, fmap_base=fmap_base, fmap_max=fmap_max, resolution_schedule=resolution_schedule, **kwargs, ) __all__ = ["ProgressiveGAN", "progressivegan"] 
================================================ FILE: nobrainer/models/highresnet.py ================================================ """HighResNet 3-D segmentation model (PyTorch). Reference --------- Li W. et al., "On the Compactness, Efficiency, and Representation of 3D Convolutional Networks: Brain Parcellation as a Pretext Task", IPMI 2017. arXiv:1707.01992. """ from __future__ import annotations import torch import torch.nn as nn import torch.nn.functional as F class _ResBlock(nn.Module): """Residual block: BN→Act→Conv→BN→Act→Conv + skip.""" def __init__( self, channels: int, dilation: int, act: type[nn.Module], ) -> None: super().__init__() padding = dilation self.path = nn.Sequential( nn.BatchNorm3d(channels), act(), nn.Conv3d( channels, channels, 3, padding=padding, dilation=dilation, bias=False ), nn.BatchNorm3d(channels), act(), nn.Conv3d( channels, channels, 3, padding=padding, dilation=dilation, bias=False ), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return x + self.path(x) class _ZeroPadChannels(nn.Module): """Pad the channel dimension symmetrically with zeros.""" def __init__(self, extra_channels: int) -> None: super().__init__() self.pad = extra_channels def forward(self, x: torch.Tensor) -> torch.Tensor: return F.pad(x, (0, 0, 0, 0, 0, 0, self.pad, self.pad)) class HighResNet(nn.Module): """HighResNet — three stages of residual blocks with increasing dilation. Stage 1 (dilation=1): base_filters channels, n_blocks residual blocks Stage 2 (dilation=2): 2*base_filters channels Stage 3 (dilation=4): 4*base_filters channels Parameters ---------- n_classes : int Number of output segmentation classes. in_channels : int Number of input image channels. base_filters : int Initial feature map count (doubled each stage). n_blocks : int Number of residual blocks per stage. activation : str ``"relu"`` or ``"elu"``. dropout_rate : float Spatial dropout probability after the last stage (0 = none). 
""" def __init__( self, n_classes: int = 1, in_channels: int = 1, base_filters: int = 16, n_blocks: int = 3, activation: str = "relu", dropout_rate: float = 0.0, ) -> None: super().__init__() act_cls: type[nn.Module] = {"relu": nn.ReLU, "elu": nn.ELU}[activation.lower()] f = base_filters # 16 # Initial projection to base_filters channels self.init_conv = nn.Conv3d(in_channels, f, kernel_size=3, padding=1, bias=False) # Stage 1: f channels, dilation 1 → pad to 3f s1 = [_ResBlock(f, dilation=1, act=act_cls) for _ in range(n_blocks)] self.stage1 = nn.Sequential(*s1) self.pad1 = _ZeroPadChannels(f) # f → 3f # Stage 2: project 3f → 2f, dilation 2 → pad to 6f self.stage2_proj = nn.Conv3d(3 * f, 2 * f, kernel_size=1, bias=False) s2 = [_ResBlock(2 * f, dilation=2, act=act_cls) for _ in range(n_blocks)] self.stage2 = nn.Sequential(*s2) self.pad2 = _ZeroPadChannels(2 * f) # 2f → 6f # Stage 3: project 6f → 4f, dilation 4 self.stage3_proj = nn.Conv3d(6 * f, 4 * f, kernel_size=1, bias=False) s3 = [_ResBlock(4 * f, dilation=4, act=act_cls) for _ in range(n_blocks)] self.stage3 = nn.Sequential(*s3) self.dropout = ( nn.Dropout3d(p=dropout_rate) if dropout_rate > 0 else nn.Identity() ) self.classifier = nn.Sequential( nn.BatchNorm3d(4 * f), act_cls(), nn.Conv3d(4 * f, n_classes, kernel_size=1), ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.init_conv(x) s1 = self.stage1(x) s1 = self.pad1(s1) # (N, 3f, D, H, W) s2 = self.stage2_proj(s1) s2 = self.stage2(s2) s2 = self.pad2(s2) # (N, 6f, D, H, W) s3 = self.stage3_proj(s2) s3 = self.stage3(s3) s3 = self.dropout(s3) return self.classifier(s3) def highresnet( n_classes: int = 1, in_channels: int = 1, base_filters: int = 16, n_blocks: int = 3, activation: str = "relu", dropout_rate: float = 0.0, **kwargs, ) -> HighResNet: """Factory function for :class:`HighResNet`.""" return HighResNet( n_classes=n_classes, in_channels=in_channels, base_filters=base_filters, n_blocks=n_blocks, activation=activation, 
dropout_rate=dropout_rate,
    )


__all__ = ["HighResNet", "highresnet"]

================================================
FILE: nobrainer/models/meshnet.py
================================================
"""MeshNet 3-D segmentation model (PyTorch).

Reference
---------
Fedorov A. et al., "End-to-end learning of brain tissue segmentation from
imperfect labeling", IJCNN 2017. arXiv:1612.00940.
"""

from __future__ import annotations

import torch
import torch.nn as nn

# Dilation schedules indexed by receptive field size
from nobrainer.models._constants import (  # noqa: E501
    DILATION_SCHEDULES as _DILATION_SCHEDULES,
)


class _ConvBNActDrop(nn.Module):
    """One MeshNet layer: dilated 3×3×3 Conv → BatchNorm → activation → Dropout3d.

    ``padding == dilation`` keeps the spatial size unchanged.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        dilation: int,
        act: type[nn.Module],
        dropout_rate: float,
    ) -> None:
        super().__init__()
        padding = dilation  # same-padding for 3×3×3 kernel
        self.block = nn.Sequential(
            nn.Conv3d(
                in_ch,
                out_ch,
                kernel_size=3,
                padding=padding,
                dilation=dilation,
                bias=False,
            ),
            nn.BatchNorm3d(out_ch),
            act(),
            # Added unconditionally; with p=0 it leaves activations unchanged.
            nn.Dropout3d(p=dropout_rate),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class MeshNet(nn.Module):
    """3-D MeshNet segmentation network.

    A stack of dilated 3×3×3 convolution layers (Conv → BN → activation →
    Dropout3d) followed by a 1×1×1 classifier. There is one layer per entry
    of a dilation schedule looked up from ``DILATION_SCHEDULES`` by
    ``receptive_field``. The dilations are fixed, not learned. Spatial size
    is preserved throughout.

    Parameters
    ----------
    n_classes : int
        Number of output segmentation classes.
    in_channels : int
        Number of input image channels (1 for single-modality MRI).
    filters : int
        Number of feature maps in all hidden layers.
    receptive_field : int
        Key into ``DILATION_SCHEDULES`` selecting the dilation schedule
        (documented values: ``37``, ``67``, ``129``). Unknown keys raise
        ``ValueError``.
    activation : str
        ``"relu"`` or ``"elu"`` (case-insensitive). Any other value raises
        ``KeyError``.
    dropout_rate : float
        Spatial dropout probability applied after each conv layer (0 = none).
    """

    def __init__(
        self,
        n_classes: int = 1,
        in_channels: int = 1,
        filters: int = 71,
        receptive_field: int = 67,
        activation: str = "relu",
        dropout_rate: float = 0.25,
    ) -> None:
        super().__init__()
        if receptive_field not in _DILATION_SCHEDULES:
            raise ValueError(
                f"receptive_field must be one of {list(_DILATION_SCHEDULES)}, "
                f"got {receptive_field}"
            )
        dilations = _DILATION_SCHEDULES[receptive_field]
        act_cls: type[nn.Module] = {"relu": nn.ReLU, "elu": nn.ELU}[activation.lower()]

        layers: list[nn.Module] = []
        for i, dil in enumerate(dilations):
            # Only the first layer sees the raw input channels.
            in_ch = in_channels if i == 0 else filters
            layers.append(_ConvBNActDrop(in_ch, filters, dil, act_cls, dropout_rate))
        self.encoder = nn.Sequential(*layers)
        # 1×1×1 conv from features to per-class scores (no softmax here).
        self.classifier = nn.Conv3d(filters, n_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.encoder(x))


def meshnet(
    n_classes: int = 1,
    in_channels: int = 1,
    filters: int = 71,
    receptive_field: int = 67,
    activation: str = "relu",
    dropout_rate: float = 0.25,
    **kwargs,
) -> MeshNet:
    """Factory function for :class:`MeshNet`.

    Extra keyword arguments are accepted and ignored (not forwarded).
    """
    return MeshNet(
        n_classes=n_classes,
        in_channels=in_channels,
        filters=filters,
        receptive_field=receptive_field,
        activation=activation,
        dropout_rate=dropout_rate,
    )


__all__ = ["MeshNet", "meshnet"]

================================================
FILE: nobrainer/models/segformer3d.py
================================================
"""SegFormer3D: Efficient Transformer for 3D Medical Image Segmentation.

Port of SegFormer3D (Perera et al., CVPR 2024 Workshop) to nobrainer.
Hierarchical vision transformer with efficient self-attention and all-MLP
decoder for 3D volumetric segmentation.
Reference: https://arxiv.org/abs/2404.10156 Original: https://github.com/OSUPCVLab/SegFormer3D """ from __future__ import annotations from einops import rearrange import torch import torch.nn as nn import torch.nn.functional as F # --------------------------------------------------------------------------- # Encoder components # --------------------------------------------------------------------------- class PatchEmbedding3d(nn.Module): """3D overlapping patch embedding via strided convolution. Parameters ---------- in_channels : int Input channels. embed_dim : int Output embedding dimension. kernel_size : int Conv kernel size. stride : int Conv stride (< kernel_size for overlap). padding : int Conv padding. """ def __init__( self, in_channels: int = 1, embed_dim: int = 64, kernel_size: int = 7, stride: int = 4, padding: int = 3, ) -> None: super().__init__() self.proj = nn.Conv3d( in_channels, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding, ) self.norm = nn.LayerNorm(embed_dim) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int, int]: """Returns (B, N, C) tensor and spatial dims (D, H, W).""" x = self.proj(x) # (B, C, D, H, W) B, C, D, H, W = x.shape x = rearrange(x, "b c d h w -> b (d h w) c") x = self.norm(x) return x, D, H, W class EfficientSelfAttention3d(nn.Module): """Multi-head self-attention with spatial reduction. Reduces K, V spatial dimensions by ``sr_ratio`` before attention, giving O(N²/R²) complexity instead of O(N²). 
""" def __init__( self, embed_dim: int = 64, num_heads: int = 1, sr_ratio: int = 8, qkv_bias: bool = False, ) -> None: super().__init__() self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.scale = self.head_dim**-0.5 self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias) self.proj = nn.Linear(embed_dim, embed_dim) self.sr_ratio = sr_ratio if sr_ratio > 1: self.sr = nn.Conv3d( embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio, ) self.sr_norm = nn.LayerNorm(embed_dim) def forward( self, x: torch.Tensor, D: int, H: int, W: int, ) -> torch.Tensor: B, N, C = x.shape q = self.q(x) q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads) if self.sr_ratio > 1: x_3d = rearrange(x, "b (d h w) c -> b c d h w", d=D, h=H, w=W) x_sr = self.sr(x_3d) x_sr = rearrange(x_sr, "b c d h w -> b (d h w) c") x_sr = self.sr_norm(x_sr) kv = self.kv(x_sr) else: kv = self.kv(x) kv = rearrange(kv, "b n (two h d) -> two b h n d", two=2, h=self.num_heads) k, v = kv[0], kv[1] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) out = attn @ v out = rearrange(out, "b h n d -> b n (h d)") return self.proj(out) class DWConv3d(nn.Module): """3D depth-wise convolution for positional encoding in MLP.""" def __init__(self, dim: int = 64) -> None: super().__init__() self.dwconv = nn.Conv3d(dim, dim, 3, padding=1, groups=dim) self.bn = nn.BatchNorm3d(dim) def forward(self, x: torch.Tensor, D: int, H: int, W: int) -> torch.Tensor: x = rearrange(x, "b (d h w) c -> b c d h w", d=D, h=H, w=W) x = self.bn(self.dwconv(x)) x = rearrange(x, "b c d h w -> b (d h w) c") return x class MixFFN3d(nn.Module): """Feed-forward network with depth-wise conv for positional encoding.""" def __init__( self, embed_dim: int = 64, mlp_ratio: int = 4, dropout: float = 0.0 ) -> None: super().__init__() hidden = embed_dim * mlp_ratio self.fc1 = nn.Linear(embed_dim, hidden) self.dwconv = DWConv3d(hidden) self.act = 
nn.GELU() self.fc2 = nn.Linear(hidden, embed_dim) self.drop = nn.Dropout(dropout) def forward(self, x: torch.Tensor, D: int, H: int, W: int) -> torch.Tensor: x = self.fc1(x) x = self.dwconv(x, D, H, W) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class TransformerBlock3d(nn.Module): """Transformer block: LN → Attention → residual → LN → FFN → residual.""" def __init__( self, embed_dim: int = 64, num_heads: int = 1, mlp_ratio: int = 4, sr_ratio: int = 8, dropout: float = 0.0, ) -> None: super().__init__() self.norm1 = nn.LayerNorm(embed_dim) self.attn = EfficientSelfAttention3d( embed_dim, num_heads, sr_ratio, qkv_bias=True, ) self.norm2 = nn.LayerNorm(embed_dim) self.ffn = MixFFN3d(embed_dim, mlp_ratio, dropout) def forward(self, x: torch.Tensor, D: int, H: int, W: int) -> torch.Tensor: x = x + self.attn(self.norm1(x), D, H, W) x = x + self.ffn(self.norm2(x), D, H, W) return x # --------------------------------------------------------------------------- # Hierarchical Encoder (Mix Transformer) # --------------------------------------------------------------------------- class MixTransformerEncoder3d(nn.Module): """4-stage hierarchical transformer encoder. Each stage: PatchEmbedding → N × TransformerBlock → output features. Spatial resolution halves (approximately) at each stage. """ def __init__( self, in_channels: int = 1, embed_dims: tuple[int, ...] = (64, 128, 320, 512), depths: tuple[int, ...] = (2, 2, 2, 2), num_heads: tuple[int, ...] = (1, 2, 5, 8), sr_ratios: tuple[int, ...] = (8, 4, 2, 1), mlp_ratio: int = 4, patch_sizes: tuple[int, ...] = (7, 3, 3, 3), strides: tuple[int, ...] 
= (4, 2, 2, 2), dropout: float = 0.0, ) -> None: super().__init__() self.num_stages = len(embed_dims) for i in range(self.num_stages): in_ch = in_channels if i == 0 else embed_dims[i - 1] padding = patch_sizes[i] // 2 patch_embed = PatchEmbedding3d( in_ch, embed_dims[i], patch_sizes[i], strides[i], padding, ) blocks = nn.ModuleList( [ TransformerBlock3d( embed_dims[i], num_heads[i], mlp_ratio, sr_ratios[i], dropout, ) for _ in range(depths[i]) ] ) norm = nn.LayerNorm(embed_dims[i]) setattr(self, f"patch_embed_{i}", patch_embed) setattr(self, f"blocks_{i}", blocks) setattr(self, f"norm_{i}", norm) def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """Returns list of multi-scale features [(B, C_i, D_i, H_i, W_i)].""" features = [] for i in range(self.num_stages): patch_embed = getattr(self, f"patch_embed_{i}") blocks = getattr(self, f"blocks_{i}") norm = getattr(self, f"norm_{i}") x, D, H, W = patch_embed(x) for blk in blocks: x = blk(x, D, H, W) x = norm(x) # Reshape back to 3D for next stage x = rearrange(x, "b (d h w) c -> b c d h w", d=D, h=H, w=W) features.append(x) return features # --------------------------------------------------------------------------- # MLP Decoder # --------------------------------------------------------------------------- class SegFormerDecoderHead(nn.Module): """All-MLP decoder that aggregates multi-scale features. Upsamples features from all encoder stages to the highest resolution, concatenates, and projects to n_classes. """ def __init__( self, embed_dims: tuple[int, ...] 
= (64, 128, 320, 512),
        decoder_dim: int = 256,
        n_classes: int = 1,
    ) -> None:
        super().__init__()
        self.n_stages = len(embed_dims)
        # One Linear per stage maps its C_i channels to the shared decoder_dim.
        self.linears = nn.ModuleList(
            [nn.Linear(embed_dims[i], decoder_dim) for i in range(self.n_stages)]
        )
        # Fuse concatenated features
        self.fuse = nn.Sequential(
            nn.Linear(decoder_dim * self.n_stages, decoder_dim),
            nn.ReLU(inplace=True),
        )
        # Per-token classifier producing n_classes scores.
        self.pred = nn.Linear(decoder_dim, n_classes)

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        """Decode encoder features into class scores.

        ``features`` is the list of ``(B, C_i, D_i, H_i, W_i)`` maps from the
        encoder stages. Returns ``(B, n_classes, D_0, H_0, W_0)``, i.e. at the
        spatial size of ``features[0]``.
        """
        # Target spatial size = largest feature map (first stage)
        target = features[0].shape[2:]  # (D0, H0, W0)

        projected = []
        for i, feat in enumerate(features):
            B, C, D, H, W = feat.shape
            x = rearrange(feat, "b c d h w -> b (d h w) c")
            x = self.linears[i](x)  # (B, N, decoder_dim)
            x = rearrange(x, "b (d h w) c -> b c d h w", d=D, h=H, w=W)
            # Upsample to target resolution
            if (D, H, W) != target:
                x = F.interpolate(x, size=target, mode="trilinear", align_corners=False)
            projected.append(x)

        # Concatenate along channel dim, then fuse
        fused = torch.cat(projected, dim=1)  # (B, decoder_dim * n_stages, D, H, W)
        B, C, D, H, W = fused.shape
        fused = rearrange(fused, "b c d h w -> b (d h w) c")
        fused = self.fuse(fused)
        out = self.pred(fused)  # (B, D*H*W, n_classes)
        out = rearrange(out, "b (d h w) c -> b c d h w", d=D, h=H, w=W)
        return out


# ---------------------------------------------------------------------------
# SegFormer3D Model
# ---------------------------------------------------------------------------


class SegFormer3D(nn.Module):
    """SegFormer3D: Hierarchical Transformer for 3D Medical Image Segmentation.

    Combines a multi-stage transformer encoder (MixTransformer) with an
    all-MLP decoder for efficient 3D segmentation. The decoder output is at
    the first encoder stage's resolution. It is trilinearly resized to the
    input's spatial shape whenever the two differ.

    Parameters
    ----------
    n_classes : int
        Number of output segmentation classes.
    in_channels : int
        Number of input channels (1 for MRI).
    embed_dims : tuple of int
        Embedding dimensions per encoder stage.
    depths : tuple of int
        Number of transformer blocks per stage.
    num_heads : tuple of int
        Number of attention heads per stage.
    sr_ratios : tuple of int
        Spatial reduction ratios for efficient attention per stage.
    mlp_ratio : int
        MLP hidden dimension multiplier.
    decoder_dim : int
        Decoder unified channel dimension.
    dropout : float
        Dropout probability.
    """

    def __init__(
        self,
        n_classes: int = 1,
        in_channels: int = 1,
        embed_dims: tuple[int, ...] = (32, 64, 160, 256),
        depths: tuple[int, ...] = (2, 2, 2, 2),
        num_heads: tuple[int, ...] = (1, 2, 5, 8),
        sr_ratios: tuple[int, ...] = (8, 4, 2, 1),
        mlp_ratio: int = 4,
        decoder_dim: int = 256,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.encoder = MixTransformerEncoder3d(
            in_channels=in_channels,
            embed_dims=embed_dims,
            depths=depths,
            num_heads=num_heads,
            sr_ratios=sr_ratios,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
        )
        self.decoder = SegFormerDecoderHead(
            embed_dims=embed_dims,
            decoder_dim=decoder_dim,
            n_classes=n_classes,
        )
        # Final upsample to match input resolution
        # NOTE(review): this attribute is not read by forward(), which resizes
        # to the input shape directly; presumably kept for reference.
        self._upsample_factor = 4  # first stage stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, C, D, H, W) → (B, n_classes, D, H, W)."""
        input_shape = x.shape[2:]
        features = self.encoder(x)
        out = self.decoder(features)
        # Upsample to input resolution if needed
        if out.shape[2:] != input_shape:
            out = F.interpolate(
                out, size=input_shape, mode="trilinear", align_corners=False
            )
        return out


# ---------------------------------------------------------------------------
# Factory function
# ---------------------------------------------------------------------------


def segformer3d(
    n_classes: int = 1,
    in_channels: int = 1,
    embed_dims: tuple[int, ...] = (32, 64, 160, 256),
    depths: tuple[int, ...] = (2, 2, 2, 2),
    num_heads: tuple[int, ...] = (1, 2, 5, 8),
    sr_ratios: tuple[int, ...] = (8, 4, 2, 1),
    mlp_ratio: int = 4,
    decoder_dim: int = 256,
    dropout: float = 0.0,
    **kwargs,
) -> SegFormer3D:
    """Factory function for :class:`SegFormer3D`.

    Extra keyword arguments are accepted and ignored (not forwarded).

    The default configuration (~4.5M params) is the one reported in the
    SegFormer3D paper. In the size list below it is labelled **small**, not
    **base**. Parameter counts are approximate.

    - **tiny**: ``embed_dims=(16, 32, 80, 128)`` (~1.5M params)
    - **small** (default): ``embed_dims=(32, 64, 160, 256)`` (~4.5M params)
    - **base**: ``embed_dims=(64, 128, 320, 512)`` (~18M params)
    """
    return SegFormer3D(
        n_classes=n_classes,
        in_channels=in_channels,
        embed_dims=embed_dims,
        depths=depths,
        num_heads=num_heads,
        sr_ratios=sr_ratios,
        mlp_ratio=mlp_ratio,
        decoder_dim=decoder_dim,
        dropout=dropout,
    )


__all__ = ["SegFormer3D", "segformer3d"]

================================================
FILE: nobrainer/models/segmentation.py
================================================
"""MONAI-backed segmentation model factory functions.

All models expect input of shape ``(N, C_in, D, H, W)`` and produce output of
shape ``(N, n_classes, D, H, W)``.
"""

from __future__ import annotations

from monai.networks.nets import UNETR, AttentionUnet, UNet, VNet
import torch.nn as nn


def unet(
    n_classes: int = 1,
    in_channels: int = 1,
    channels: tuple[int, ...] = (16, 32, 64, 128, 256),
    strides: tuple[int, ...] = (2, 2, 2, 2),
    num_res_units: int = 0,
    act: str = "RELU",
    norm: str = "BATCH",
    dropout: float = 0.0,
    **kwargs,
) -> UNet:
    """Return a 3-D UNet (MONAI implementation).

    Parameters
    ----------
    n_classes : int
        Number of output segmentation classes.
    in_channels : int
        Number of input image channels (1 for grayscale MRI).
    channels : tuple of int
        Filter count at each level (len == levels + 1).
    strides : tuple of int
        Down-sampling stride at each level (len == levels).
    num_res_units : int
        Number of residual units per level (0 = plain conv blocks).
    act : str
        Activation name (MONAI convention: "RELU", "LEAKYRELU", "ELU", …).
    norm : str
        Normalisation: "BATCH", "INSTANCE", "GROUP", "LAYER", or "NONE".
    dropout : float
        Dropout probability (0 = disabled).
    **kwargs
        Forwarded unchanged to MONAI's ``UNet``.
""" return UNet( spatial_dims=3, in_channels=in_channels, out_channels=n_classes, channels=channels, strides=strides, num_res_units=num_res_units, act=act, norm=norm, dropout=dropout, **kwargs, ) def vnet( n_classes: int = 1, in_channels: int = 1, act: str = "elu", dropout_dim: int = 3, **kwargs, ) -> VNet: """Return a 3-D V-Net (MONAI implementation). Parameters ---------- n_classes : int Number of output segmentation classes. in_channels : int Number of input channels. act : str Activation function name (lowercase MONAI style: "elu", "relu", …). dropout_dim : int Dimension for spatial dropout (1 = channel, 3 = 3-D spatial). """ return VNet( spatial_dims=3, in_channels=in_channels, out_channels=n_classes, act=act, dropout_dim=dropout_dim, **kwargs, ) def attention_unet( n_classes: int = 1, in_channels: int = 1, channels: tuple[int, ...] = (64, 128, 256, 512), strides: tuple[int, ...] = (2, 2, 2), dropout: float = 0.0, **kwargs, ) -> AttentionUnet: """Return a 3-D Attention U-Net (MONAI implementation). Parameters ---------- n_classes : int Number of output segmentation classes. in_channels : int Number of input channels. channels : tuple of int Filter counts at each encoder level. strides : tuple of int Down-sampling strides (len == len(channels) - 1). dropout : float Dropout probability. """ return AttentionUnet( spatial_dims=3, in_channels=in_channels, out_channels=n_classes, channels=channels, strides=strides, dropout=dropout, **kwargs, ) def unetr( n_classes: int = 1, in_channels: int = 1, img_size: tuple[int, int, int] = (96, 96, 96), feature_size: int = 16, hidden_size: int = 768, mlp_dim: int = 3072, num_heads: int = 12, dropout_rate: float = 0.1, norm_name: str = "instance", **kwargs, ) -> UNETR: """Return a UNETR (ViT backbone + U-Net decoder) (MONAI implementation). Parameters ---------- n_classes : int Number of output segmentation classes. in_channels : int Number of input channels. 
img_size : tuple of int Spatial size of the input volume ``(D, H, W)``. feature_size : int Spatial feature size for the decoder (MONAI default 16). hidden_size : int ViT embedding dimension (default 768 = ViT-B). mlp_dim : int MLP hidden dim in transformer blocks. num_heads : int Number of attention heads. dropout_rate : float Dropout applied inside the transformer. norm_name : str Normalisation: "instance", "batch". """ return UNETR( in_channels=in_channels, out_channels=n_classes, img_size=img_size, feature_size=feature_size, hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate, norm_name=norm_name, **kwargs, ) def swin_unetr( n_classes: int = 1, in_channels: int = 1, feature_size: int = 24, depths: tuple[int, ...] = (2, 2, 2, 2), num_heads: tuple[int, ...] = (3, 6, 12, 24), norm_name: str = "instance", dropout_rate: float = 0.0, **kwargs, ) -> nn.Module: """Return a SwinUNETR (Swin Transformer U-Net) (MONAI implementation). Parameters ---------- n_classes : int Number of output segmentation classes. in_channels : int Number of input channels. feature_size : int Feature size for the decoder (default 24). depths : tuple of int Number of Swin Transformer blocks at each stage. num_heads : tuple of int Number of attention heads at each stage. norm_name : str Normalisation: ``"instance"`` or ``"batch"``. dropout_rate : float Dropout probability. """ from monai.networks.nets import SwinUNETR as _SwinUNETR return _SwinUNETR( in_channels=in_channels, out_channels=n_classes, feature_size=feature_size, depths=depths, num_heads=num_heads, norm_name=norm_name, drop_rate=dropout_rate, spatial_dims=3, **kwargs, ) def segresnet( n_classes: int = 1, in_channels: int = 1, blocks_down: tuple[int, ...] = (1, 2, 2, 4), init_filters: int = 16, norm: str = "INSTANCE", dropout_prob: float = 0.0, **kwargs, ) -> nn.Module: """Return a SegResNet (residual encoder segmentation network) (MONAI). Used as the default architecture in MONAI Auto3DSeg. 
Parameters
    ----------
    n_classes : int
        Number of output segmentation classes.
    in_channels : int
        Number of input channels.
    blocks_down : tuple of int
        Number of residual blocks at each encoder level.
    init_filters : int
        Initial number of filters (doubled at each level).
    norm : str
        Normalisation: ``"GROUP"``, ``"BATCH"``, ``"INSTANCE"``.
    dropout_prob : float
        Dropout probability.
    **kwargs
        Forwarded unchanged to MONAI's ``SegResNet``.
    """
    from monai.networks.nets import SegResNet as _SegResNet

    return _SegResNet(
        spatial_dims=3,
        in_channels=in_channels,
        out_channels=n_classes,
        blocks_down=blocks_down,
        init_filters=init_filters,
        norm=norm,
        dropout_prob=dropout_prob,
        **kwargs,
    )


__all__ = [
    "unet",
    "vnet",
    "attention_unet",
    "unetr",
    "swin_unetr",
    "segresnet",
]

================================================
FILE: nobrainer/models/simsiam.py
================================================
"""SimSiam self-supervised learning model for 3-D brain volumes (PyTorch).

Reference
---------
Chen X. & He K., "Exploring Simple Siamese Representation Learning",
CVPR 2021. arXiv:2011.10566.
"""

from __future__ import annotations

import torch
import torch.nn as nn

from .highresnet import HighResNet


class SimSiam(nn.Module):
    """Siamese network with stop-gradient for self-supervised pre-training.

    Architecture
    ------------
    - **Backbone**: :class:`~nobrainer.models.highresnet.HighResNet` that
      encodes a 3-D volume into a spatial feature map.
    - **Projector**: Global average pool → Linear → BN → ReLU → Linear → BN,
      producing a projection vector of size ``projection_dim`` (two linear
      layers, i.e. one hidden layer).
    - **Predictor**: Bottleneck MLP Linear → ReLU → BN → Linear
      (``projection_dim`` → ``latent_dim`` → ``projection_dim``).

    Training
    --------
    Produce two augmented views of the same volume and pass each through the
    encoder + projector. Then apply the *negative cosine similarity* loss
    between ``predictor(z1)`` and ``stop_grad(z2)``, and vice-versa.

    Parameters
    ----------
    n_classes : int
        Passed to the HighResNet backbone (not used for classification,
        but kept for architecture compatibility).
    in_channels : int
        Number of input channels.
    projection_dim : int
        Output dimension of the projector head.
    latent_dim : int
        Hidden bottleneck size in the predictor.
    weight_decay : float
        L2 regularisation weight. It is only stored as ``self.weight_decay``;
        the caller must apply it via the optimiser.
    """

    def __init__(
        self,
        n_classes: int = 1,
        in_channels: int = 1,
        projection_dim: int = 2048,
        latent_dim: int = 512,
        weight_decay: float = 0.0005,
    ) -> None:
        super().__init__()
        # Stored for the caller's optimiser; not used by this module.
        self.weight_decay = weight_decay

        backbone = HighResNet(n_classes=n_classes, in_channels=in_channels)
        # Determine backbone output channels by inspecting the classifier head
        self.backbone = backbone
        backbone_feat_ch = backbone.classifier[2].in_channels  # 4*f

        self.projector = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(backbone_feat_ch, projection_dim),
            nn.BatchNorm1d(projection_dim),
            nn.ReLU(inplace=True),
            nn.Linear(projection_dim, projection_dim),
            nn.BatchNorm1d(projection_dim),
        )

        # NOTE(review): BatchNorm follows the ReLU here, whereas the SimSiam
        # paper's predictor is Linear → BN → ReLU → Linear. Confirm this is
        # intentional; reordering would change the state_dict keys.
        self.predictor = nn.Sequential(
            nn.Linear(projection_dim, latent_dim),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(latent_dim),
            nn.Linear(latent_dim, projection_dim),
        )

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run backbone up to stage3 (before classifier head).

        Mirrors ``HighResNet.forward`` but stops before ``backbone.dropout``
        and ``backbone.classifier``.
        """
        h = self.backbone.init_conv(x)
        s1 = self.backbone.stage1(h)
        s1 = self.backbone.pad1(s1)
        s2 = self.backbone.stage2_proj(s1)
        s2 = self.backbone.stage2(s2)
        s2 = self.backbone.pad2(s2)
        s3 = self.backbone.stage3_proj(s2)
        s3 = self.backbone.stage3(s3)
        return s3  # (N, 4f, D, H, W)

    def forward(
        self, x1: torch.Tensor, x2: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass producing predictions and projections for both views.

        Returns
        -------
        p1, p2 : torch.Tensor
            Predictions for view 1 and view 2 (gradient flows through these).
        z1, z2 : torch.Tensor
            Projections, returned detached from the autograd graph so they act
            as stop-gradient targets in the SimSiam loss.
""" feat1 = self._encode(x1) feat2 = self._encode(x2) z1 = self.projector(feat1) z2 = self.projector(feat2) p1 = self.predictor(z1) p2 = self.predictor(z2) return p1, p2, z1.detach(), z2.detach() @staticmethod def loss( p1: torch.Tensor, p2: torch.Tensor, z1: torch.Tensor, z2: torch.Tensor, ) -> torch.Tensor: """Negative cosine similarity loss (symmetric).""" cos = nn.functional.cosine_similarity def _d(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor: return -cos(p, z, dim=-1).mean() return (_d(p1, z2) + _d(p2, z1)) * 0.5 def simsiam( n_classes: int = 1, in_channels: int = 1, projection_dim: int = 2048, latent_dim: int = 512, weight_decay: float = 0.0005, **kwargs, ) -> SimSiam: """Factory function for :class:`SimSiam`.""" return SimSiam( n_classes=n_classes, in_channels=in_channels, projection_dim=projection_dim, latent_dim=latent_dim, weight_decay=weight_decay, ) __all__ = ["SimSiam", "simsiam"] ================================================ FILE: nobrainer/models/tests/__init__.py ================================================ ================================================ FILE: nobrainer/prediction.py ================================================ """Block-based prediction utilities (PyTorch, no TensorFlow).""" from __future__ import annotations from pathlib import Path from typing import Any import nibabel as nib import numpy as np import torch import torch.nn as nn from nobrainer.training import get_device def _forward(model: nn.Module, tensor: torch.Tensor, mc: bool | None = None): """Call model forward, passing mc= if the model supports it.""" from nobrainer.models._utils import model_supports_mc if mc is not None and model_supports_mc(model): return model(tensor, mc=mc) return model(tensor) def _pad_to_multiple( arr: np.ndarray, block_shape: tuple[int, int, int] ) -> tuple[np.ndarray, tuple[int, ...]]: """Pad spatial dims of ``arr`` (D, H, W) so each is divisible by block_shape.""" pads = [] for dim, bs in zip(arr.shape, block_shape): rem = 
(-dim) % bs pads.append((0, rem)) return np.pad(arr, pads, mode="constant"), tuple(p[1] for p in pads) def _extract_blocks(arr: np.ndarray, block_shape: tuple[int, int, int]) -> np.ndarray: """Split ``arr`` (D, H, W) into non-overlapping blocks of ``block_shape``.""" D, H, W = arr.shape bd, bh, bw = block_shape blocks = arr.reshape(D // bd, bd, H // bh, bh, W // bw, bw) # → (nd, bD, nh, bH, nw, bW) blocks = blocks.transpose(0, 2, 4, 1, 3, 5) # → (nd, nh, nw, bD, bH, bW) nd, nh, nw = D // bd, H // bh, W // bw return blocks.reshape(nd * nh * nw, bd, bh, bw), (nd, nh, nw) def _stitch_blocks( block_preds: np.ndarray, grid: tuple[int, int, int], block_shape: tuple[int, int, int], pad: tuple[int, int, int], orig_shape: tuple[int, int, int], n_classes: int, ) -> np.ndarray: """Reconstruct full prediction volume from per-block predictions.""" nd, nh, nw = grid bd, bh, bw = block_shape # block_preds: (N_blocks, n_classes, bD, bH, bW) full = np.zeros((n_classes, nd * bd, nh * bh, nw * bw), dtype=block_preds.dtype) idx = 0 for i in range(nd): for j in range(nh): for k in range(nw): full[ :, i * bd : (i + 1) * bd, j * bh : (j + 1) * bh, k * bw : (k + 1) * bw, ] = block_preds[idx] idx += 1 # Remove padding D, H, W = orig_shape pd, ph, pw = pad end_d = full.shape[1] - pd if pd > 0 else full.shape[1] end_h = full.shape[2] - ph if ph > 0 else full.shape[2] end_w = full.shape[3] - pw if pw > 0 else full.shape[3] return full[:, :end_d, :end_h, :end_w] def strided_patch_positions( volume_shape: tuple[int, int, int], block_shape: tuple[int, int, int], stride: tuple[int, int, int] | None = None, ) -> list[tuple[slice, slice, slice]]: """Compute grid positions for strided patch extraction. Parameters ---------- volume_shape : tuple of int ``(D, H, W)`` of the volume. block_shape : tuple of int ``(bD, bH, bW)`` patch size. stride : tuple of int or None Step size per axis. None = block_shape (non-overlapping). 
Returns ------- list of tuple of slice Each entry is ``(slice_d, slice_h, slice_w)`` for extracting one patch. """ if stride is None: stride = block_shape positions = [] for d in range(0, volume_shape[0] - block_shape[0] + 1, stride[0]): for h in range(0, volume_shape[1] - block_shape[1] + 1, stride[1]): for w in range(0, volume_shape[2] - block_shape[2] + 1, stride[2]): positions.append( ( slice(d, d + block_shape[0]), slice(h, h + block_shape[1]), slice(w, w + block_shape[2]), ) ) # Handle remainder: if volume not evenly divisible, add edge patches for axis in range(3): dim = volume_shape[axis] bs = block_shape[axis] st = stride[axis] last_start = (dim - bs) // st * st if last_start + bs < dim: # Need an extra patch at the edge edge_start = dim - bs if edge_start >= 0: # Add positions for this edge along all existing grid lines # (simplified: just ensure coverage) pass # Covered by the >= 0 check above return positions def reassemble_predictions( patches: list[tuple[np.ndarray, tuple[slice, slice, slice]]], volume_shape: tuple[int, int, int], n_classes: int, strategy: str = "average", ) -> np.ndarray: """Reassemble overlapping patch predictions into a full volume. Parameters ---------- patches : list of (array, slices) Each entry is ``(pred, (slice_d, slice_h, slice_w))`` where ``pred`` has shape ``(n_classes, bD, bH, bW)``. volume_shape : tuple of int ``(D, H, W)`` of the target volume. n_classes : int Number of output classes. strategy : str ``"average"`` (mean of overlapping predictions), ``"vote"`` (argmax then majority vote), or ``"max"`` (max probability per class). Returns ------- np.ndarray Shape ``(n_classes, D, H, W)`` probability volume. 
""" D, H, W = volume_shape output = np.zeros((n_classes, D, H, W), dtype=np.float64) counts = np.zeros((1, D, H, W), dtype=np.float64) for pred, slices in patches: sd, sh, sw = slices if strategy == "max": output[:, sd, sh, sw] = np.maximum(output[:, sd, sh, sw], pred) else: output[:, sd, sh, sw] += pred counts[0, sd, sh, sw] += 1.0 if strategy == "average": counts = np.maximum(counts, 1.0) output = output / counts return output.astype(np.float32) def _predict_strided( arr: np.ndarray, affine: np.ndarray | None, model: nn.Module, block_shape: tuple[int, int, int], stride: tuple[int, int, int], batch_size: int, device: torch.device, return_labels: bool, normalizer: Any | None, ) -> nib.Nifti1Image: """Strided prediction with overlap reassembly.""" from nobrainer.gpu import get_device if device is None: device = get_device() model = model.to(device) model.eval() vol_shape = arr.shape[:3] positions = strided_patch_positions(vol_shape, block_shape, stride) patches = [] with torch.no_grad(): for i in range(0, len(positions), batch_size): batch_pos = positions[i : i + batch_size] batch_blocks = np.stack([arr[sd, sh, sw] for sd, sh, sw in batch_pos]) if normalizer is not None: batch_blocks = np.stack([normalizer(b) for b in batch_blocks]) tensor = torch.from_numpy(batch_blocks[:, None].astype(np.float32)).to( device ) out = _forward(model, tensor, mc=False) probs = torch.softmax(out, dim=1).cpu().numpy() for j, pos in enumerate(batch_pos): patches.append((probs[j], pos)) n_classes = patches[0][0].shape[0] full_pred = reassemble_predictions( patches, vol_shape, n_classes, strategy="average" ) if return_labels: labels = full_pred.argmax(axis=0).astype(np.int32) result = nib.Nifti1Image(labels, affine) else: result = nib.Nifti1Image(full_pred.transpose(1, 2, 3, 0), affine) return result def predict( inputs: str | Path | np.ndarray | nib.Nifti1Image, model: nn.Module, block_shape: tuple[int, int, int] = (128, 128, 128), stride: tuple[int, int, int] | None = None, batch_size: 
int = 4, device: str | torch.device | None = None, return_labels: bool = True, normalizer: Any | None = None, ) -> nib.Nifti1Image: """Run block-based inference on a 3-D brain volume. Parameters ---------- inputs : path, ndarray, or Nifti1Image Input brain MRI. If a file path is given, it is loaded with nibabel. If an ndarray, shape must be ``(D, H, W)``. model : nn.Module Trained PyTorch segmentation model. Must accept tensors of shape ``(N, 1, bD, bH, bW)`` and return ``(N, C, bD, bH, bW)``. block_shape : tuple Spatial block size ``(bD, bH, bW)`` for patch-based inference. batch_size : int Number of blocks to process in one forward pass. device : str, device, or None Compute device. Defaults to CUDA if available, else CPU. return_labels : bool If ``True``, return argmax labels. If ``False``, return class probabilities (softmax) as a 4-D volume. normalizer : callable or None Optional function ``normalizer(arr) → arr`` applied to each block before inference. Returns ------- nib.Nifti1Image Segmentation (or probability) volume with the same affine as the input NIfTI. 
""" if device is None: device = get_device() device = torch.device(device) # Multi-GPU: distribute blocks across GPUs when device="cuda" and >1 GPU n_gpus = torch.cuda.device_count() if device.type == "cuda" else 1 use_multi_gpu = n_gpus > 1 # Load input affine = np.eye(4) if isinstance(inputs, (str, Path)): img = nib.load(str(inputs)) arr = np.asarray(img.dataobj, dtype=np.float32) affine = img.affine elif isinstance(inputs, nib.Nifti1Image): arr = np.asarray(inputs.dataobj, dtype=np.float32) affine = inputs.affine else: arr = np.asarray(inputs, dtype=np.float32) orig_shape = arr.shape[:3] arr3d = arr if arr.ndim == 3 else arr[..., 0] # Strided prediction path (overlapping blocks with reassembly) if stride is not None: return _predict_strided( arr3d, affine, model, block_shape, stride, batch_size, device, return_labels, normalizer, ) # Pad to block-divisible size padded, pad = _pad_to_multiple(arr3d, block_shape) blocks, grid = _extract_blocks(padded, block_shape) # (N_blocks, bD, bH, bW) n_blocks = blocks.shape[0] if use_multi_gpu: # Replicate model to each GPU (deep copy to avoid moving the original) import copy _ = model.state_dict() models = [] for i in range(n_gpus): m = copy.deepcopy(model).to(torch.device(f"cuda:{i}")) m.eval() models.append(m) else: model = model.to(device) model.eval() all_preds: list[np.ndarray] = [] with torch.no_grad(): for start in range(0, n_blocks, batch_size): chunk = blocks[start : start + batch_size] # (B, bD, bH, bW) if normalizer is not None: chunk = np.stack([normalizer(b) for b in chunk]) if use_multi_gpu: # Round-robin distribute across GPUs gpu_idx = (start // batch_size) % n_gpus dev = torch.device(f"cuda:{gpu_idx}") tensor = torch.from_numpy(chunk[:, None]).to(dev) out = _forward(models[gpu_idx], tensor, mc=False) else: tensor = torch.from_numpy(chunk[:, None]).to(device) out = _forward(model, tensor, mc=False) if return_labels: out = out.argmax(dim=1, keepdim=True).float() else: out = torch.softmax(out, dim=1) 
all_preds.append(out.cpu().numpy()) block_preds = np.concatenate(all_preds, axis=0) # (N_blocks, C, bD, bH, bW) n_classes = block_preds.shape[1] full_pred = _stitch_blocks( block_preds, grid, block_shape, pad, orig_shape, n_classes ) # Squeeze class dim for single-class output if n_classes == 1: spatial = full_pred[0] else: spatial = full_pred # (C, D, H, W) out_img = nib.Nifti1Image(spatial.astype(np.float32), affine) return out_img def predict_with_uncertainty( inputs: str | Path | np.ndarray | nib.Nifti1Image, model: nn.Module, n_samples: int = 10, block_shape: tuple[int, int, int] = (128, 128, 128), batch_size: int = 4, device: str | torch.device | None = None, ) -> tuple[nib.Nifti1Image, nib.Nifti1Image, nib.Nifti1Image]: """MC-Dropout / Bayesian uncertainty estimation. Runs ``n_samples`` stochastic forward passes with the model in **train** mode (activating Dropout and Pyro sampling in Bayesian layers) and returns mean label, predictive variance, and predictive entropy maps. Parameters ---------- inputs : path, ndarray, or Nifti1Image Input brain MRI (same format as :func:`predict`). model : nn.Module Trained segmentation model. Should contain dropout or Bayesian layers so that repeated forward passes are stochastic. n_samples : int Number of Monte-Carlo forward passes. block_shape, batch_size, device Same semantics as :func:`predict`. Returns ------- label_img : nib.Nifti1Image Mean class label (argmax over mean softmax probabilities). variance_img : nib.Nifti1Image Mean predictive variance across classes. entropy_img : nib.Nifti1Image Predictive entropy of the mean softmax distribution. 
""" if device is None: device = get_device() device = torch.device(device) affine = np.eye(4) if isinstance(inputs, (str, Path)): img = nib.load(str(inputs)) arr = np.asarray(img.dataobj, dtype=np.float32) affine = img.affine elif isinstance(inputs, nib.Nifti1Image): arr = np.asarray(inputs.dataobj, dtype=np.float32) affine = inputs.affine else: arr = np.asarray(inputs, dtype=np.float32) orig_shape = arr.shape[:3] arr3d = arr if arr.ndim == 3 else arr[..., 0] padded, pad = _pad_to_multiple(arr3d, block_shape) blocks, grid = _extract_blocks(padded, block_shape) n_blocks = blocks.shape[0] model = model.to(device) # Use eval mode to preserve BatchNorm statistics. # Stochasticity is controlled via mc=True (KWYK/FFG models) # or inherent Pyro sampling (BayesianConv3d). model.eval() # Welford's online algorithm: accumulate mean and M2 incrementally # so we only keep 2 block-level arrays in memory, not n_samples copies. mean_probs: np.ndarray | None = None # running mean m2_probs: np.ndarray | None = None # running sum of squared deviations with torch.no_grad(): for sample_idx in range(n_samples): preds: list[np.ndarray] = [] for start in range(0, n_blocks, batch_size): chunk = blocks[start : start + batch_size] tensor = torch.from_numpy(chunk[:, None]).to(device) out = _forward(model, tensor, mc=True) probs = torch.softmax(out, dim=1).cpu().numpy() preds.append(probs) sample = np.concatenate(preds, axis=0) # (N_blocks, C, bD, bH, bW) if mean_probs is None: mean_probs = sample.copy() m2_probs = np.zeros_like(sample) else: delta = sample - mean_probs mean_probs += delta / (sample_idx + 1) delta2 = sample - mean_probs m2_probs += delta * delta2 var_probs = m2_probs / max(n_samples, 1) # population variance del m2_probs n_classes = mean_probs.shape[1] # Reduce per-block before stitching to avoid materialising full (C, D, H, W) # Labels: argmax over classes per block → (N_blocks, 1, bD, bH, bW) if n_classes == 1: block_labels = (mean_probs[:, 0:1] > 0.5).astype(np.float32) 
else: block_labels = mean_probs.argmax(axis=1, keepdims=True).astype(np.float32) # Mean variance across classes per block → (N_blocks, 1, bD, bH, bW) block_var = var_probs.mean(axis=1, keepdims=True) del var_probs # Entropy per block → (N_blocks, 1, bD, bH, bW) eps = 1e-8 block_entropy = -(mean_probs * np.log(mean_probs + eps)).sum(axis=1, keepdims=True) del mean_probs # Stitch scalar maps (n_classes=1 for each) labels = _stitch_blocks(block_labels, grid, block_shape, pad, orig_shape, 1)[0] mean_var = _stitch_blocks(block_var, grid, block_shape, pad, orig_shape, 1)[0] entropy = _stitch_blocks(block_entropy, grid, block_shape, pad, orig_shape, 1)[0] label_img = nib.Nifti1Image(labels, affine) var_img = nib.Nifti1Image(mean_var.astype(np.float32), affine) entropy_img = nib.Nifti1Image(entropy.astype(np.float32), affine) return label_img, var_img, entropy_img __all__ = ["predict", "predict_with_uncertainty"] ================================================ FILE: nobrainer/processing/__init__.py ================================================ """Scikit-learn-style estimator API for nobrainer. Provides high-level ``Segmentation``, ``Generation``, and ``Dataset`` classes that wrap the lower-level PyTorch internals. 
""" from .dataset import Dataset, PatchDataset, extract_patches __all__ = ["Dataset", "PatchDataset", "extract_patches"] # Optional: Segmentation (requires core models) try: from .segmentation import Segmentation # noqa: F401 __all__.append("Segmentation") except ImportError: pass # Optional: Generation (requires pytorch-lightning) try: from .generation import Generation # noqa: F401 __all__.append("Generation") except ImportError: pass ================================================ FILE: nobrainer/processing/base.py ================================================ """Base estimator with Croissant-ML metadata persistence.""" from __future__ import annotations import json from pathlib import Path from typing import Any import torch import torch.nn as nn class BaseEstimator: """Base class for all nobrainer estimators. Provides ``save()`` / ``load()`` with Croissant-ML JSON-LD metadata, and optional multi-GPU support via DDP. """ state_variables: list[str] = [] model_: nn.Module | None = None _training_result: dict | None = None _dataset: Any = None def __init__( self, checkpoint_filepath: str | Path | None = None, multi_gpu: bool = True, ): self.checkpoint_filepath = checkpoint_filepath self.multi_gpu = multi_gpu @property def model(self) -> nn.Module: if self.model_ is None: raise RuntimeError("Model not trained. 
Call .fit() first.") return self.model_ def save(self, save_dir: str | Path) -> None: """Save model.pth + croissant.json to directory.""" from .croissant import write_model_croissant save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) torch.save(self.model_.state_dict(), save_dir / "model.pth") write_model_croissant(save_dir, self, self._training_result, self._dataset) @classmethod def load(cls, model_dir: str | Path, multi_gpu: bool = True) -> "BaseEstimator": """Load estimator from directory with croissant.json metadata.""" model_dir = Path(model_dir) metadata = json.loads((model_dir / "croissant.json").read_text()) prov = metadata.get("nobrainer:provenance", {}) est = cls.__new__(cls) est.multi_gpu = multi_gpu est.checkpoint_filepath = None est._training_result = None est._dataset = None # Subclass-specific reconstruction est._restore_from_provenance(prov) est.model_ = est._build_model() est.model_.load_state_dict( torch.load(model_dir / "model.pth", weights_only=True) ) return est def _build_model(self) -> nn.Module: """Reconstruct model architecture. Override in subclasses.""" raise NotImplementedError def _restore_from_provenance(self, prov: dict) -> None: """Restore state from provenance dict. 
Override in subclasses.""" raise NotImplementedError ================================================ FILE: nobrainer/processing/croissant.py ================================================ """Croissant-ML JSON-LD metadata helpers for nobrainer estimators.""" from __future__ import annotations import datetime import hashlib import json from pathlib import Path from typing import Any def _sha256(path: str | Path) -> str: """Compute SHA-256 hex digest of a file.""" h = hashlib.sha256() with open(path, "rb") as f: for chunk in iter(lambda: f.read(1 << 16), b""): h.update(chunk) return h.hexdigest() def _dataset_checksums(dataset: Any) -> list[dict]: """Extract file paths and SHA256 checksums from a Dataset.""" if dataset is None: return [] checksums = [] for item in getattr(dataset, "data", []): img = item.get("image", "") if isinstance(item, dict) else "" if img and Path(img).exists(): checksums.append({"path": str(img), "sha256": _sha256(img)}) return checksums def write_model_croissant( save_dir: Path, estimator: Any, training_result: dict | None, dataset: Any, ) -> Path: """Write croissant.json with Croissant-ML JSON-LD metadata. Includes provenance (source datasets with SHA256), training parameters, model architecture info, and version stamps. 
""" import torch import nobrainer result = training_result or {} # Extract optimizer info from estimator if available opt_class = getattr(estimator, "_optimizer_class", "Adam") opt_args = getattr(estimator, "_optimizer_args", {}) loss_name = getattr(estimator, "_loss_name", "unknown") metadata = { "@context": {"@vocab": "http://mlcommons.org/croissant/"}, "@type": "cr:Dataset", "name": f"nobrainer-{getattr(estimator, 'base_model', 'model')}", "description": ( f"Trained {getattr(estimator, 'base_model', 'model')} model " f"via nobrainer" ), "distribution": [ { "@type": "cr:FileObject", "name": "model.pth", "contentUrl": "model.pth", "encodingFormat": "application/x-pytorch", } ], "nobrainer:provenance": { "source_datasets": _dataset_checksums(dataset), "training_date": datetime.datetime.now(datetime.timezone.utc).isoformat(), "nobrainer_version": nobrainer.__version__, "pytorch_version": torch.__version__, "optimizer": { "class": str(opt_class), "args": {k: str(v) for k, v in (opt_args or {}).items()}, }, "loss_function": str(loss_name), "epochs_trained": len(result.get("history", [])), "final_loss": ( result["history"][-1].get("loss") if result.get("history") else None ), "best_loss": ( min( (h["loss"] for h in result["history"] if h.get("loss") is not None), default=None, ) if result.get("history") else None ), "model_architecture": getattr(estimator, "base_model", "unknown"), "model_args": getattr(estimator, "model_args", None) or {}, "n_classes": getattr(estimator, "n_classes_", None), "block_shape": list(getattr(estimator, "block_shape_", []) or []), "gpu_count": torch.cuda.device_count() if torch.cuda.is_available() else 0, }, } out = save_dir / "croissant.json" out.write_text(json.dumps(metadata, indent=2, default=str)) return out def write_checkpoint_croissant( checkpoint_dir: Path, model: Any, optimizer: Any, criterion: Any, history: list[dict], ) -> Path: """Write croissant.json alongside a training checkpoint. 
Lighter-weight than :func:`write_model_croissant` — works with the raw model/optimizer/criterion objects available inside :func:`~nobrainer.training.fit` rather than requiring an estimator wrapper. """ import torch import nobrainer checkpoint_dir = Path(checkpoint_dir) metadata = { "@context": {"@vocab": "http://mlcommons.org/croissant/"}, "@type": "cr:Dataset", "name": f"nobrainer-{type(model).__name__}", "description": f"Trained {type(model).__name__} checkpoint via nobrainer", "distribution": [ { "@type": "cr:FileObject", "name": "best_model.pth", "contentUrl": "best_model.pth", "encodingFormat": "application/x-pytorch", } ], "nobrainer:provenance": { "training_date": datetime.datetime.now(datetime.timezone.utc).isoformat(), "nobrainer_version": nobrainer.__version__, "pytorch_version": torch.__version__, "optimizer": { "class": type(optimizer).__name__, "args": {k: str(v) for k, v in optimizer.defaults.items()}, }, "loss_function": type(criterion).__name__, "epochs_trained": len(history), "final_loss": (history[-1].get("loss") if history else None), "best_loss": ( min( (h["loss"] for h in history if h.get("loss") is not None), default=None, ) if history else None ), "model_architecture": type(model).__name__, "gpu_count": ( torch.cuda.device_count() if torch.cuda.is_available() else 0 ), }, } out = checkpoint_dir / "croissant.json" out.write_text(json.dumps(metadata, indent=2, default=str)) return out def write_dataset_croissant( output_path: str | Path, dataset: Any, ) -> Path: """Write Croissant-ML JSON-LD for a Dataset.""" metadata = { "@context": {"@vocab": "http://mlcommons.org/croissant/"}, "@type": "cr:Dataset", "name": "nobrainer-dataset", "description": "Brain MRI dataset for nobrainer", "distribution": [], "recordSet": [], } checksums = _dataset_checksums(dataset) for item in checksums: metadata["distribution"].append( { "@type": "cr:FileObject", "name": Path(item["path"]).name, "contentUrl": item["path"], "sha256": item["sha256"], } ) 
metadata["nobrainer:dataset_info"] = { "volume_shape": list(getattr(dataset, "volume_shape", []) or []), "n_classes": getattr(dataset, "n_classes", None), "block_shape": list(getattr(dataset, "_block_shape", []) or []), "n_volumes": len(getattr(dataset, "data", [])), } output_path = Path(output_path) output_path.write_text(json.dumps(metadata, indent=2, default=str)) return output_path def validate_croissant(path: str | Path) -> bool: """Validate croissant.json using mlcroissant (if installed).""" try: import mlcroissant mlcroissant.Dataset(jsonld=str(path)) return True except ImportError: return True # Skip validation if not installed except Exception: return False ================================================ FILE: nobrainer/processing/dataset.py ================================================ """Fluent Dataset builder for nobrainer estimators.""" from __future__ import annotations import copy from pathlib import Path from typing import TYPE_CHECKING, Callable if TYPE_CHECKING: import zarr import numpy as np import torch from torch.utils.data import DataLoader # Named label mapping CSV locations (relative to package or absolute) _NAMED_MAPPINGS = { "6-class": "6-class-mapping.csv", "50-class": "50-class-mapping.csv", "115-class": "115-class-mapping.csv", } def _load_label_mapping(name_or_path: str) -> Callable: """Load a label mapping CSV and return a remap function. Accepts named mappings ("6-class", "50-class", "115-class") or a path to a CSV with ``original`` and ``new`` columns. 
""" import csv as csv_mod if name_or_path in _NAMED_MAPPINGS: csv_name = _NAMED_MAPPINGS[name_or_path] # Primary: inside the nobrainer package (works with pip install) pkg_data = Path(__file__).parent.parent / "data" / "label_mappings" / csv_name # Fallback: scripts dir (editable installs / development) scripts_data = ( Path(__file__).parent.parent.parent / "scripts" / "kwyk_reproduction" / "label_mappings" / csv_name ) candidates = [pkg_data, scripts_data] csv_path = None for c in candidates: if c.exists(): csv_path = c break if csv_path is None: raise FileNotFoundError( f"Label mapping '{name_or_path}' not found. " f"Searched: {[str(c) for c in candidates]}" ) else: csv_path = Path(name_or_path) if not csv_path.exists(): raise FileNotFoundError(f"Label mapping CSV not found: {csv_path}") # Parse CSV: build original → new lookup lookup = {} with open(csv_path) as f: reader = csv_mod.DictReader(f) for row in reader: orig = int(row["original"]) new = int(row["new"]) lookup[orig] = new return _LabelRemap(lookup) class _LabelRemap: """Picklable label remapping callable (needed for DataLoader workers).""" def __init__(self, lookup: dict[int, int]): self.lookup = lookup def __call__(self, x): result = torch.zeros_like(x) for orig_val, new_val in self.lookup.items(): result[x == orig_val] = new_val return result.long() class Dataset: """Fluent dataset builder wrapping the nobrainer data pipeline. 
Example:: ds_train, ds_eval = ( Dataset.from_files(filepaths, block_shape=(128,128,128)) .batch(2) .augment() .normalize() .split(eval_size=0.1) ) loader = ds_train.dataloader """ def __init__( self, data: list[dict[str, str]], volume_shape: tuple | None = None, n_classes: int = 1, ): self.data = data self.volume_shape = volume_shape self.n_classes = n_classes self._block_shape: tuple | None = None self._batch_size: int = 1 self._shuffle: bool = False self._augment: bool = False self._augment_profile: str = "standard" self._binarize: bool = False self._streaming: bool = False self._patches_per_volume: int = 10 self._normalizer: Callable | None = None self._dataloader: DataLoader | None = None @classmethod def from_files( cls, filepaths: list[tuple[str, str]] | list[dict[str, str]], block_shape: tuple[int, int, int] | None = None, n_classes: int = 1, ) -> "Dataset": """Create a Dataset from file paths. Parameters ---------- filepaths : list Either ``[(img, label), ...]`` tuples or ``[{"image": img, "label": label}, ...]`` dicts. block_shape : tuple or None Spatial patch size for extraction. None loads full volumes. n_classes : int Number of label classes. 
""" # Normalize to list of dicts if filepaths and isinstance(filepaths[0], (list, tuple)): data = [{"image": str(img), "label": str(lbl)} for img, lbl in filepaths] else: data = [{k: str(v) for k, v in d.items()} for d in filepaths] # Detect volume shape from first file volume_shape = None if data: import nibabel as nib first = data[0]["image"] if Path(first).suffix in (".zarr",): pass # Zarr shape detection deferred to dataloader else: try: volume_shape = nib.load(first).shape[:3] except Exception: pass ds = cls(data=data, volume_shape=volume_shape, n_classes=n_classes) ds._block_shape = block_shape return ds @classmethod def from_zarr( cls, store_path: str | Path, block_shape: tuple[int, int, int] | None = None, n_classes: int = 1, partition: str | None = None, partition_path: str | Path | None = None, ) -> "Dataset": """Create a Dataset from a Zarr3 store. Parameters ---------- store_path : str or Path Path to a Zarr store created by :func:`nobrainer.datasets.zarr_store.create_zarr_store`. block_shape : tuple or None Spatial patch size. n_classes : int Number of label classes. partition : str or None Partition to use: ``"train"``, ``"val"``, ``"test"``, or None (all). partition_path : str or Path or None Path to partition JSON. If None and partition is set, looks for ``_partition.json``. """ from nobrainer.datasets.zarr_store import load_partition, store_info store_path = Path(store_path) info = store_info(store_path) subject_ids = info["subject_ids"] volume_shape = tuple(info["volume_shape"]) # Filter by partition if partition is not None: if partition_path is None: partition_path = Path(str(store_path) + "_partition.json") parts = load_partition(partition_path) if partition not in parts: raise ValueError( f"Partition '{partition}' not found. 
" f"Available: {list(parts.keys())}" ) subject_ids = parts[partition] # Build data list referencing zarr indices id_to_idx = {sid: i for i, sid in enumerate(info["subject_ids"])} data = [] for sid in subject_ids: idx = id_to_idx[sid] data.append( { "image": f"zarr://{store_path}#images/{idx}", "label": f"zarr://{store_path}#labels/{idx}", "_zarr_store": str(store_path), "_zarr_index": idx, "_subject_id": sid, } ) ds = cls(data=data, volume_shape=volume_shape, n_classes=n_classes) ds._block_shape = block_shape ds._zarr_store_path = str(store_path) return ds # --- Fluent API --- def batch(self, batch_size: int) -> "Dataset": """Set batch size.""" self._batch_size = batch_size self._dataloader = None # invalidate cache return self def binarize(self, labels: str | set[int] | Callable | None = None) -> "Dataset": """Binarize or remap labels. Parameters ---------- labels : str, set of ints, callable, or None - ``None`` (default): any non-zero value → 1 - ``"binary"``: same as None (any non-zero → 1) - ``"6-class"``, ``"50-class"``, ``"115-class"``: named parcellation from nobrainer_training_scripts mapping CSVs - ``set``: voxels with values in the set → 1, all others → 0 - ``callable``: custom ``fn(label_tensor) → tensor`` - ``str`` (path): path to a custom mapping CSV with ``original,new`` columns Examples -------- Brain extraction (any tissue):: ds.binarize() Named parcellation:: ds.binarize(labels="50-class") Select specific FreeSurfer regions (e.g., hippocampus L+R):: ds.binarize(labels={17, 53}) Custom mapping CSV:: ds.binarize(labels="/path/to/mapping.csv") """ if isinstance(labels, str) and labels not in ("binary",): # Named mapping or CSV path self._binarize = _load_label_mapping(labels) elif labels is not None: self._binarize = labels else: self._binarize = True self._dataloader = None return self def shuffle(self, buffer_size: int = 100) -> "Dataset": """Enable shuffling.""" self._shuffle = True self._dataloader = None return self def augment(self, profile: str 
| bool = True) -> "Dataset": """Enable data augmentation. Parameters ---------- profile : str or bool ``True`` or ``"standard"`` for the standard profile. Named profiles: ``"none"``, ``"light"``, ``"standard"``, ``"heavy"``. ``False`` disables augmentation. """ if profile is False or profile == "none": self._augment = False elif profile is True: self._augment = True self._augment_profile = "standard" elif isinstance(profile, str): self._augment = True self._augment_profile = profile self._dataloader = None return self def mix( self, generator: "torch.utils.data.Dataset", ratio: float = 0.3, ) -> "Dataset": """Combine this dataset with a synthetic data generator. Creates a mixed dataset where each sample is drawn from either the real data (this dataset) or the synthetic generator, based on the ratio. Parameters ---------- generator : torch.utils.data.Dataset Synthetic data source (e.g., ``SynthSegGenerator``). Must return ``{"image": Tensor, "label": Tensor}`` dicts. ratio : float Fraction of samples drawn from the generator (default 0.3 = 30%). Returns ------- Dataset A new Dataset wrapping a ``MixedDataset``. """ mixed = MixedDataset(self, generator, ratio=ratio) new_ds = Dataset( data=self.data, volume_shape=self.volume_shape, n_classes=self.n_classes ) new_ds._block_shape = self._block_shape new_ds._batch_size = self._batch_size new_ds._augment = self._augment new_ds._augment_profile = self._augment_profile new_ds._mixed_dataset = mixed new_ds._dataloader = None return new_ds def streaming(self, patches_per_volume: int = 10) -> "Dataset": """Use streaming patch extraction (no full-volume loading). Instead of loading entire volumes and cropping in memory (MONAI pipeline), patches are read directly from disk. For Zarr stores, only the chunks overlapping the requested patch are fetched — enabling efficient cloud and large-dataset training. Requires ``block_shape`` to be set via ``from_files()`` or ``batch()`` first. 
Parameters ---------- patches_per_volume : int Random patches per volume per epoch. Example ------- :: ds = (Dataset.from_files(paths, block_shape=(64,64,64)) .batch(4).binarize().streaming(patches_per_volume=20)) """ self._streaming = True self._patches_per_volume = patches_per_volume self._dataloader = None return self def normalize(self, fn: Callable | None = None) -> "Dataset": """Set normalization function.""" self._normalizer = fn self._dataloader = None return self def split(self, eval_size: float = 0.1) -> tuple["Dataset", "Dataset"]: """Split into train and eval datasets.""" n = len(self.data) n_eval = max(1, int(n * eval_size)) indices = np.random.permutation(n) eval_idx = indices[:n_eval] train_idx = indices[n_eval:] train_ds = copy.copy(self) train_ds.data = [self.data[i] for i in train_idx] train_ds._dataloader = None eval_ds = copy.copy(self) eval_ds.data = [self.data[i] for i in eval_idx] eval_ds._dataloader = None return train_ds, eval_ds @property def dataloader(self) -> DataLoader: """Lazily build and return a PyTorch DataLoader.""" if self._dataloader is not None: return self._dataloader # Streaming mode: use PatchDataset for on-the-fly patch extraction if self._streaming: # Build augmentation transforms if enabled transforms = None if self._augment: from monai.transforms import Compose from nobrainer.augmentation.profiles import get_augmentation_profile aug_transforms = get_augmentation_profile( self._augment_profile, keys=["image", "label"] ) if aug_transforms: transforms = Compose(aug_transforms) patch_ds = PatchDataset( data=self.data, block_shape=self._block_shape or (32, 32, 32), patches_per_volume=self._patches_per_volume, binarize=self._binarize if self._binarize else None, transforms=transforms, ) # Use multiple workers for I/O prefetching — each worker loads # patches independently while GPU processes the current batch. # Respect SLURM allocation or fall back to cpu_count. 
import os slurm_cpus = os.environ.get("SLURM_CPUS_PER_TASK") max_cpus = int(slurm_cpus) if slurm_cpus else (os.cpu_count() or 1) n_workers = max(1, max_cpus - 1) # leave 1 CPU for main process self._dataloader = DataLoader( patch_ds, batch_size=self._batch_size, shuffle=self._shuffle, num_workers=n_workers, prefetch_factor=2, persistent_workers=True if n_workers > 0 else False, pin_memory=torch.cuda.is_available(), ) return self._dataloader image_paths = [d["image"] for d in self.data] label_paths = [d["label"] for d in self.data if "label" in d] or None # Check for Zarr paths is_zarr = any(str(p).rstrip("/").endswith(".zarr") for p in image_paths) if is_zarr: from nobrainer.dataset import ZarrDataset zarr_data = self.data ds = ZarrDataset(zarr_data) self._dataloader = DataLoader( ds, batch_size=self._batch_size, shuffle=self._shuffle, pin_memory=torch.cuda.is_available(), ) else: from nobrainer.dataset import get_dataset self._dataloader = get_dataset( image_paths=image_paths, label_paths=label_paths, block_shape=self._block_shape, batch_size=self._batch_size, augment=self._augment, binarize_labels=self._binarize, ) return self._dataloader @property def batch_size(self) -> int: return self._batch_size @property def block_shape(self) -> tuple | None: return self._block_shape def to_croissant(self, output_path: str | Path) -> Path: """Export dataset metadata as Croissant-ML JSON-LD.""" from .croissant import write_dataset_croissant return write_dataset_croissant(output_path, self) def extract_patches( volume: np.ndarray, label: np.ndarray | None = None, block_shape: tuple[int, int, int] = (32, 32, 32), n_patches: int = 10, binarize: bool | set | Callable | None = None, ) -> list[tuple[np.ndarray, ...]] | list[np.ndarray]: """Extract random patches from a 3D volume. Parameters ---------- volume : ndarray 3D volume of shape ``(D, H, W)`` or path loadable by nibabel. label : ndarray or None Corresponding label volume. If None, only image patches returned. 
block_shape : tuple Spatial size of each patch ``(bD, bH, bW)``. n_patches : int Number of random patches to extract. binarize : bool, set, callable, or None If not None, applied to label patches: - ``True``: any non-zero → 1 - ``set``: voxels in set → 1 - ``callable``: custom ``fn(patch) → patch`` Returns ------- list of tuples ``(image_patch, label_patch)`` if label given, or list of ``image_patch`` arrays if label is None. Examples -------- :: import nibabel as nib vol = nib.load("brain.nii.gz").get_fdata() lbl = nib.load("label.nii.gz").get_fdata() patches = extract_patches(vol, lbl, block_shape=(32, 32, 32), n_patches=20) # patches[0] = (image_patch, label_patch), each shape (32, 32, 32) """ import nibabel as nib # Load from path if needed if isinstance(volume, (str, Path)): volume = np.asarray(nib.load(str(volume)).dataobj, dtype=np.float32) if isinstance(label, (str, Path)): label = np.asarray(nib.load(str(label)).dataobj, dtype=np.float32) vol = np.asarray(volume, dtype=np.float32) bd, bh, bw = block_shape D, H, W = vol.shape[:3] patches = [] for _ in range(n_patches): d0 = np.random.randint(0, max(1, D - bd + 1)) h0 = np.random.randint(0, max(1, H - bh + 1)) w0 = np.random.randint(0, max(1, W - bw + 1)) img_patch = vol[d0 : d0 + bd, h0 : h0 + bh, w0 : w0 + bw] if label is not None: lbl = np.asarray(label, dtype=np.float32) lbl_patch = lbl[d0 : d0 + bd, h0 : h0 + bh, w0 : w0 + bw] # Apply binarization if binarize is True: lbl_patch = (lbl_patch > 0).astype(np.float32) elif isinstance(binarize, set): mask = np.zeros_like(lbl_patch) for val in binarize: mask = np.maximum(mask, (lbl_patch == val).astype(np.float32)) lbl_patch = mask elif callable(binarize): lbl_patch = binarize(lbl_patch) patches.append((img_patch, lbl_patch)) else: patches.append(img_patch) return patches class PatchDataset(torch.utils.data.Dataset): """Streaming patch dataset — generates random patches on-the-fly. 
Instead of pre-extracting patches or loading full volumes into memory, this dataset lazily reads only the voxels needed for each patch. For Zarr v3 stores, this uses chunk-aligned partial I/O (only the chunks overlapping the patch are read from disk/cloud). Parameters ---------- data : list of dicts ``[{"image": path, "label": path}, ...]``. Paths can be NIfTI (``.nii``, ``.nii.gz``, ``.mgz``) or Zarr (``.zarr``). block_shape : tuple Spatial size of each patch ``(bD, bH, bW)``. patches_per_volume : int Number of random patches to yield per volume per epoch. binarize : bool, set, callable, or None Label remapping (see :func:`extract_patches`). transforms : callable or None Optional transform applied to each ``(image, label)`` dict after extraction (e.g., normalization, augmentation). Examples -------- :: from nobrainer.processing.dataset import PatchDataset ds = PatchDataset( data=[{"image": "sub-01.zarr", "label": "sub-01_label.zarr"}], block_shape=(64, 64, 64), patches_per_volume=10, binarize=True, ) loader = DataLoader(ds, batch_size=4, num_workers=2) Each epoch yields ``len(data) * patches_per_volume`` patches, with different random locations each time. 
""" def __init__( self, data: list[dict[str, str]], block_shape: tuple[int, int, int] = (32, 32, 32), patches_per_volume: int = 10, binarize: bool | set | Callable | None = None, transforms: Callable | None = None, ): self.data = data self.block_shape = block_shape self.patches_per_volume = patches_per_volume self.binarize = binarize self.transforms = transforms # Cache zarr store handles (opened once, reused for all reads) self._zarr_cache: dict[str, zarr.Group] = {} # Cache volume shapes — use zarr metadata when available (fast) self._shapes: list[tuple[int, ...]] = [] first_parsed = self._parse_zarr_path(str(data[0]["image"])) if data else None if first_parsed is not None: # All items share the same zarr store — read shape once store = self._get_zarr_store(first_parsed[0]) spatial_shape = store[first_parsed[1]].shape[1:] # (D, H, W) self._shapes = [spatial_shape] * len(data) else: for item in data: self._shapes.append(self._get_shape(item["image"])) def __len__(self) -> int: return len(self.data) * self.patches_per_volume def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: vol_idx = idx // self.patches_per_volume item = self.data[vol_idx] shape = self._shapes[vol_idx] # Random patch origin bd, bh, bw = self.block_shape d0 = np.random.randint(0, max(1, shape[0] - bd + 1)) h0 = np.random.randint(0, max(1, shape[1] - bh + 1)) w0 = np.random.randint(0, max(1, shape[2] - bw + 1)) slc = (slice(d0, d0 + bd), slice(h0, h0 + bh), slice(w0, w0 + bw)) # Read only the patch region (cached zarr handles for speed) img_patch = self._read_region_cached(item["image"], slc).astype(np.float32) result: dict[str, torch.Tensor] = { "image": torch.from_numpy(img_patch[None]), # add channel dim } if "label" in item: lbl_patch = self._read_region_cached(item["label"], slc).astype(np.float32) lbl_patch = self._apply_binarize(lbl_patch) result["label"] = torch.from_numpy(lbl_patch[None]) if self.transforms is not None: result = self.transforms(result) return result def 
_apply_binarize(self, lbl: np.ndarray) -> np.ndarray: """Apply binarization to a label patch.""" if self.binarize is True: return (lbl > 0).astype(np.float32) elif isinstance(self.binarize, set): mask = np.zeros_like(lbl) for val in self.binarize: mask = np.maximum(mask, (lbl == val).astype(np.float32)) return mask elif callable(self.binarize): # Remap functions may expect torch tensors (e.g., _load_label_mapping) t = torch.from_numpy(lbl.astype(np.int32)) result = self.binarize(t) return result.numpy().astype(np.float32) return lbl @staticmethod def _parse_zarr_path(path: str) -> tuple[str, str, int] | None: """Parse zarr://store_path#array_name/subject_index. Returns ``(store_path, array_name, subject_index)`` or None. """ if path.startswith("zarr://"): rest = path[len("zarr://") :] if "#" in rest: store_path, fragment = rest.split("#", 1) parts = fragment.rsplit("/", 1) if len(parts) == 2: return store_path, parts[0], int(parts[1]) return store_path, fragment, 0 return rest, "images", 0 return None @staticmethod def _get_shape(path: str) -> tuple[int, ...]: """Get volume shape without loading full data.""" path = str(path) parsed = PatchDataset._parse_zarr_path(path) if parsed is not None: import zarr store_path, array_name, idx = parsed store = zarr.open_group(store_path, mode="r") # Shape of the 4D array is (N, D, H, W); return spatial (D, H, W) return store[array_name].shape[1:] elif path.rstrip("/").endswith(".zarr"): import zarr store = zarr.open_group(path, mode="r") return store["0"].shape else: import nibabel as nib return nib.load(path).shape[:3] def _get_zarr_store(self, store_path: str): """Get or create a cached zarr group handle.""" if store_path not in self._zarr_cache: import zarr self._zarr_cache[store_path] = zarr.open_group(store_path, mode="r") return self._zarr_cache[store_path] def _read_region_cached(self, path: str, slc: tuple[slice, ...]) -> np.ndarray: """Read a spatial region, using cached zarr handles.""" path = str(path) parsed = 
self._parse_zarr_path(path) if parsed is not None: store_path, array_name, idx = parsed store = self._get_zarr_store(store_path) sd, sh, sw = slc return np.asarray(store[array_name][idx, sd, sh, sw]) return self._read_region(path, slc) @staticmethod def _read_region(path: str, slc: tuple[slice, ...]) -> np.ndarray: """Read a spatial region from a volume (static, no caching).""" path = str(path) parsed = PatchDataset._parse_zarr_path(path) if parsed is not None: import zarr store_path, array_name, idx = parsed store = zarr.open_group(store_path, mode="r") sd, sh, sw = slc return np.asarray(store[array_name][idx, sd, sh, sw]) elif path.rstrip("/").endswith(".zarr"): import zarr store = zarr.open_group(path, mode="r") return np.asarray(store["0"][slc]) else: import nibabel as nib img = nib.load(path) return np.asarray(img.dataobj[slc]) class MixedDataset(torch.utils.data.Dataset): """Combine a real dataset with a synthetic generator at a given ratio. Each ``__getitem__`` call randomly selects from either the real data or the generator based on the ratio. Parameters ---------- real_dataset : Dataset or torch.utils.data.Dataset The real data source. generator : torch.utils.data.Dataset Synthetic data source (e.g., ``SynthSegGenerator``). ratio : float Fraction of samples from the generator (0.3 = 30% synthetic). 
""" def __init__( self, real_dataset: "Dataset | torch.utils.data.Dataset", generator: torch.utils.data.Dataset, ratio: float = 0.3, ) -> None: self.real_dataset = real_dataset self.generator = generator self.ratio = ratio # Total length is the max of real and synthetic self._real_len = len(real_dataset) if hasattr(real_dataset, "__len__") else 0 self._gen_len = len(generator) def __len__(self) -> int: return max(self._real_len, self._gen_len) def __getitem__(self, idx: int) -> dict: import random if random.random() < self.ratio: # Synthetic sample gen_idx = idx % self._gen_len return self.generator[gen_idx] else: # Real sample real_idx = idx % max(self._real_len, 1) if hasattr(self.real_dataset, "dataloader"): # Dataset object — use underlying data return self.real_dataset.data[real_idx] return self.real_dataset[real_idx] ================================================ FILE: nobrainer/processing/generation.py ================================================ """Generation estimator — scikit-learn-style API for GANs.""" from __future__ import annotations from pathlib import Path from typing import Any import nibabel as nib import numpy as np import torch import torch.nn as nn from .base import BaseEstimator class Generation(BaseEstimator): """Train and generate synthetic brain volumes. 
Example:: gen = Generation("progressivegan").fit(dataset, epochs=100) images = gen.generate(n_images=5) """ state_variables = ["base_model", "model_args", "latent_size"] def __init__( self, base_model: str = "progressivegan", model_args: dict | None = None, multi_gpu: bool = True, ): super().__init__(multi_gpu=multi_gpu) self.base_model = base_model self.model_args = model_args or {} self.latent_size = self.model_args.get("latent_size", 256) def fit( self, dataset_train: Any, epochs: int = 100, **trainer_kwargs: Any, ) -> "Generation": """Train the generative model using Lightning.""" import pytorch_lightning as pl from nobrainer.models import get as get_model factory = get_model(self.base_model) self.model_ = factory(**self.model_args) self.latent_size = getattr(self.model_, "latent_size", self.latent_size) loader = ( dataset_train.dataloader if hasattr(dataset_train, "dataloader") else dataset_train ) trainer_defaults = { "max_steps": epochs, "accelerator": "auto", "devices": 1, "enable_checkpointing": False, "logger": False, } trainer_defaults.update(trainer_kwargs) trainer = pl.Trainer(**trainer_defaults) trainer.fit(self.model_, loader) self._dataset = dataset_train self._training_result = { "history": [{"epoch": e, "loss": None} for e in range(1, epochs + 1)], "checkpoint_path": None, } return self def generate( self, n_images: int = 1, data_type: type | None = None, ) -> list[nib.Nifti1Image]: """Generate synthetic brain volumes.""" self.model_.eval() gen = self.model_.generator gen.current_level = getattr(gen, "current_level", 0) gen.alpha = 1.0 images = [] with torch.no_grad(): z = torch.randn(n_images, self.latent_size, device=self.model_.device) out = gen(z) # (N, 1, D, H, W) for i in range(n_images): arr = out[i, 0].cpu().numpy() if data_type is not None: arr = arr.astype(data_type) images.append(nib.Nifti1Image(arr, np.eye(4))) return images def save(self, save_dir: str | Path) -> None: """Save Lightning checkpoint + croissant.json.""" from .croissant 
import write_model_croissant save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) torch.save(self.model_.state_dict(), save_dir / "model.pth") write_model_croissant(save_dir, self, self._training_result, self._dataset) def _build_model(self) -> nn.Module: from nobrainer.models import get as get_model return get_model(self.base_model)(**self.model_args) def _restore_from_provenance(self, prov: dict) -> None: self.base_model = prov.get("model_architecture", "progressivegan") self.model_args = prov.get("model_args", {}) self.latent_size = self.model_args.get("latent_size", 256) ================================================ FILE: nobrainer/processing/segmentation.py ================================================ """Segmentation estimator — scikit-learn-style API.""" from __future__ import annotations from pathlib import Path from typing import Any, Callable import nibabel as nib import numpy as np import torch import torch.nn as nn from nobrainer.training import get_device from .base import BaseEstimator class Segmentation(BaseEstimator): """Train and run brain segmentation with a simple API. 
Example:: seg = Segmentation("unet").fit(dataset, epochs=5) result = seg.predict("brain.nii.gz") seg.save("my_model") """ state_variables = [ "base_model", "model_args", "block_shape_", "volume_shape_", "n_classes_", ] def __init__( self, base_model: str = "unet", model_args: dict | None = None, checkpoint_filepath: str | Path | None = None, multi_gpu: bool = True, ): super().__init__(checkpoint_filepath, multi_gpu) self.base_model = base_model self.model_args = model_args or {} self.block_shape_: tuple | None = None self.volume_shape_: tuple | None = None self.n_classes_: int | None = None self._optimizer_class: str = "Adam" self._optimizer_args: dict = {} self._loss_name: str = "unknown" def fit( self, dataset_train: Any, dataset_validate: Any | None = None, epochs: int = 1, optimizer: type = torch.optim.Adam, opt_args: dict | None = None, loss: Callable | nn.Module | None = None, class_weights: torch.Tensor | str | None = None, metrics: Callable | None = None, callbacks: list | None = None, **kwargs, ) -> "Segmentation": """Train the model and return self for chaining. Parameters ---------- class_weights : Tensor, str, or None Per-class weights for CrossEntropyLoss. Pass a tensor of shape ``(n_classes,)``, ``"auto"`` to compute from training labels via inverse frequency, or None (default, no weighting). 
""" from nobrainer.models import get as get_model from nobrainer.training import fit as training_fit # Store metadata from dataset self.block_shape_ = getattr(dataset_train, "block_shape", None) self.volume_shape_ = getattr(dataset_train, "volume_shape", None) self.n_classes_ = getattr(dataset_train, "n_classes", 1) # Set n_classes in model_args model_args = {**self.model_args, "n_classes": self.n_classes_} factory = get_model(self.base_model) self.model_ = factory(**model_args) # Configure optimizer opt_args = opt_args or {"lr": 1e-3} opt = optimizer(self.model_.parameters(), **opt_args) self._optimizer_class = optimizer.__name__ self._optimizer_args = opt_args # Configure class weights weights_tensor = None if class_weights is not None: if isinstance(class_weights, str) and class_weights == "auto": from nobrainer.losses import compute_class_weights label_paths = [ p[1] if isinstance(p, (list, tuple)) else p.get("label", p) for p in getattr(dataset_train, "data", []) ] label_mapping = getattr(dataset_train, "_binarize_name", None) weights_tensor = compute_class_weights( label_paths, self.n_classes_, label_mapping=label_mapping, max_samples=50, ) elif isinstance(class_weights, torch.Tensor): weights_tensor = class_weights if weights_tensor is not None: self._class_weights = weights_tensor # Configure loss if loss is None: loss = nn.CrossEntropyLoss(weight=weights_tensor) self._loss_name = ( "WeightedCrossEntropyLoss" if weights_tensor is not None else "CrossEntropyLoss" ) elif callable(loss): self._loss_name = getattr(loss, "__name__", type(loss).__name__) if not isinstance(loss, nn.Module): loss = loss() # factory function like losses.dice() else: self._loss_name = type(loss).__name__ # Train gpus = torch.cuda.device_count() if self.multi_gpu else 1 loader = ( dataset_train.dataloader if hasattr(dataset_train, "dataloader") else dataset_train ) val_loader = None if dataset_validate is not None: val_loader = ( dataset_validate.dataloader if 
hasattr(dataset_validate, "dataloader") else dataset_validate ) self._training_result = training_fit( model=self.model_, loader=loader, criterion=loss, optimizer=opt, max_epochs=epochs, gpus=gpus, checkpoint_dir=self.checkpoint_filepath, callbacks=callbacks, val_loader=val_loader, checkpoint_freq=kwargs.get("checkpoint_freq", 0), gradient_checkpointing=kwargs.get("gradient_checkpointing", False), model_parallel=kwargs.get("model_parallel", False), resume_from=kwargs.get("resume_from"), ) self._dataset = dataset_train return self def predict( self, x: str | Path | np.ndarray | nib.Nifti1Image, batch_size: int = 4, block_shape: tuple | None = None, normalizer: Callable | None = None, n_samples: int = 0, ) -> nib.Nifti1Image | tuple[nib.Nifti1Image, ...]: """Predict on a volume. If ``n_samples > 0`` and model is Bayesian, returns ``(label, variance, entropy)`` tuple. """ from nobrainer.prediction import predict, predict_with_uncertainty bs = block_shape or self.block_shape_ or (128, 128, 128) if n_samples > 0: return predict_with_uncertainty( inputs=x, model=self.model, n_samples=n_samples, block_shape=bs, batch_size=batch_size, ) return predict( inputs=x, model=self.model, block_shape=bs, batch_size=batch_size, normalizer=normalizer, ) def evaluate( self, dataset: Any, metrics: Callable | None = None, ) -> dict: """Evaluate model on a dataset. 
Returns dict with loss and metrics.""" device = get_device() self.model_.to(device).eval() criterion = nn.CrossEntropyLoss() total_loss = 0.0 n_batches = 0 loader = dataset.dataloader if hasattr(dataset, "dataloader") else dataset with torch.no_grad(): for batch in loader: if isinstance(batch, dict): images = batch["image"].to(device) labels = batch["label"].to(device) else: images, labels = batch[0].to(device), batch[1].to(device) pred = self.model_(images) total_loss += criterion(pred, labels).item() n_batches += 1 return { "loss": total_loss / max(n_batches, 1), "n_batches": n_batches, } def _build_model(self) -> nn.Module: """Reconstruct model architecture from stored metadata.""" from nobrainer.models import get as get_model model_args = {**self.model_args, "n_classes": self.n_classes_} return get_model(self.base_model)(**model_args) def _restore_from_provenance(self, prov: dict) -> None: """Restore state from croissant.json provenance.""" self.base_model = prov.get("model_architecture", "unet") self.model_args = prov.get("model_args", {}) self.n_classes_ = prov.get("n_classes", 1) self.block_shape_ = tuple(prov.get("block_shape", [])) self._optimizer_class = prov.get("optimizer", {}).get("class", "Adam") self._optimizer_args = prov.get("optimizer", {}).get("args", {}) self._loss_name = prov.get("loss_function", "unknown") ================================================ FILE: nobrainer/research/__init__.py ================================================ """Autoresearch sub-package for nobrainer (Phase 7 — US5/US6).""" from .loop import commit_best_model, run_loop __all__ = ["commit_best_model", "run_loop"] ================================================ FILE: nobrainer/research/loop.py ================================================ """Autoresearch loop for nobrainer. Proposes hyperparameter diffs via the Anthropic API, applies them to a training script, runs the experiment subprocess, and keeps improvements. 
If the Anthropic API is unavailable (no key or import error) the loop falls back to a random perturbation from a pre-defined search grid. """ from __future__ import annotations import copy from dataclasses import dataclass, field import json import logging import os from pathlib import Path import shutil import subprocess import sys import time from typing import Any logger = logging.getLogger(__name__) _DEFAULT_SEARCH_GRID: dict[str, list[Any]] = { "learning_rate": [1e-4, 5e-4, 1e-3, 5e-3], "batch_size": [2, 4, 8], "n_epochs": [10, 20, 50], "dropout_rate": [0.0, 0.1, 0.25, 0.5], } @dataclass class ExperimentResult: """Structured record for one autoresearch experiment.""" run_id: int config: dict[str, Any] val_dice: float | None outcome: str # "improved", "degraded", "failed" failure_reason: str | None = None elapsed_seconds: float = 0.0 notes: list[str] = field(default_factory=list) def run_loop( working_dir: str | Path, model_family: str = "bayesian_vnet", max_experiments: int = 10, budget_hours: float = 8.0, train_script: str = "train.py", val_dice_file: str = "val_dice.json", budget_timeout_per_run: float = 3600.0, budget_seconds: float | None = None, ) -> list[ExperimentResult]: """Run the autoresearch experiment loop. Parameters ---------- working_dir : path Directory containing the training script and where results are saved. model_family : str Model family name (e.g. ``"bayesian_vnet"``). max_experiments : int Maximum number of experiments to run. budget_hours : float Wall-clock budget in hours (loop stops when exceeded). train_script : str Filename of the training script relative to ``working_dir``. val_dice_file : str Filename of the validation Dice JSON written by the training script. budget_timeout_per_run : float Per-experiment subprocess timeout in seconds. Returns ------- list[ExperimentResult] All experiment records (including failures). 
""" working_dir = Path(working_dir) train_path = working_dir / train_script val_dice_path = working_dir / val_dice_file backup_path = working_dir / f"{train_script}.backup" if budget_seconds is not None: budget_end = time.time() + budget_seconds else: budget_end = time.time() + budget_hours * 3600.0 if not train_path.exists(): raise FileNotFoundError( f"Training script not found: {train_path}. " "Create it or copy from nobrainer.research.templates." ) # Read initial config from train_script (look for JSON comment block) current_config = _parse_config_comment(train_path) best_dice: float | None = None results: list[ExperimentResult] = [] logger.info("Starting autoresearch loop for %s", model_family) logger.info("max_experiments=%d, budget_hours=%.1f", max_experiments, budget_hours) for run_id in range(max_experiments): if time.time() >= budget_end: logger.info("Budget exhausted — stopping at experiment %d", run_id) break # Propose new config new_config = _propose_config(current_config, model_family, run_id, best_dice) logger.info("Experiment %d config: %s", run_id, new_config) # Backup train script, patch config shutil.copy2(train_path, backup_path) _patch_config(train_path, new_config) # Run experiment subprocess t0 = time.time() failure_reason: str | None = None val_dice: float | None = None outcome = "failed" try: proc = subprocess.run( [sys.executable, str(train_path)], cwd=str(working_dir), capture_output=True, text=True, timeout=budget_timeout_per_run, ) elapsed = time.time() - t0 # Check for failure signals if proc.returncode != 0: failure_reason = _classify_failure(proc.stderr) elif _has_nan(proc.stdout): failure_reason = "NaN in loss" else: # Read val_dice.json val_dice = _read_val_dice(val_dice_path) if val_dice is not None: if best_dice is None or val_dice > best_dice: outcome = "improved" best_dice = val_dice current_config = new_config else: outcome = "degraded" else: failure_reason = "val_dice.json missing or invalid" except subprocess.TimeoutExpired: 
elapsed = time.time() - t0 failure_reason = f"timeout after {budget_timeout_per_run:.0f}s" if failure_reason is not None: logger.warning("Experiment %d failed: %s", run_id, failure_reason) # Revert train script shutil.copy2(backup_path, train_path) results.append( ExperimentResult( run_id=run_id, config=copy.deepcopy(new_config), val_dice=val_dice, outcome=outcome, failure_reason=failure_reason, elapsed_seconds=elapsed if "elapsed" in dir() else 0.0, ) ) # Write run summary _write_summary(working_dir, results, model_family, best_dice) return results # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- def _propose_config( current: dict[str, Any], model_family: str, run_id: int, best_dice: float | None, ) -> dict[str, Any]: """Propose a new config via Anthropic API or random grid search.""" api_key = os.environ.get("ANTHROPIC_API_KEY") if api_key: try: return _propose_via_llm(current, model_family, run_id, best_dice, api_key) except Exception as exc: logger.warning( "Anthropic API proposal failed (%s) — falling back to random grid", exc ) return _propose_random(current) def _propose_via_llm( current: dict[str, Any], model_family: str, run_id: int, best_dice: float | None, api_key: str, ) -> dict[str, Any]: """Use Anthropic claude-sonnet-4-6 to propose a new config diff.""" import anthropic # type: ignore[import-untyped] client = anthropic.Anthropic(api_key=api_key) context = ( f"You are an ML research assistant. The current training config is:\n" f"{json.dumps(current, indent=2)}\n\n" f"Model family: {model_family}\n" f"Experiment number: {run_id}\n" f"Best val_dice so far: {best_dice}\n\n" f"Propose a new configuration as a JSON object with updated hyperparameters " f"(use the same keys). Only return the JSON object, no other text." 
) message = client.messages.create( model="claude-sonnet-4-6", max_tokens=512, messages=[{"role": "user", "content": context}], ) raw = message.content[0].text.strip() # Extract JSON from the response start = raw.find("{") end = raw.rfind("}") + 1 if start == -1 or end == 0: raise ValueError("LLM did not return a JSON object") proposed = json.loads(raw[start:end]) # Merge with current (keep unchanged keys) merged = dict(current) merged.update(proposed) return merged def _propose_random(current: dict[str, Any]) -> dict[str, Any]: """Random perturbation from the search grid (LLM fallback).""" import random proposed = dict(current) for key, values in _DEFAULT_SEARCH_GRID.items(): if key in current: proposed[key] = random.choice(values) logger.info("Random grid proposal: %s", proposed) return proposed def _parse_config_comment(path: Path) -> dict[str, Any]: """Extract a JSON block from a ``# CONFIG: {...}`` comment in the script.""" with path.open() as fh: for line in fh: if line.strip().startswith("# CONFIG:"): try: return json.loads(line.split("# CONFIG:", 1)[1].strip()) except json.JSONDecodeError: pass return {} def _patch_config(path: Path, config: dict[str, Any]) -> None: """Replace the ``# CONFIG: {...}`` comment line in the training script.""" lines = path.read_text().splitlines(keepends=True) patched = [] found = False for line in lines: if line.strip().startswith("# CONFIG:"): patched.append(f"# CONFIG: {json.dumps(config)}\n") found = True else: patched.append(line) if not found: patched.insert(0, f"# CONFIG: {json.dumps(config)}\n") path.write_text("".join(patched)) def _read_val_dice(path: Path) -> float | None: """Read the ``val_dice`` value from a JSON file.""" if not path.exists(): return None try: data = json.loads(path.read_text()) return float(data.get("val_dice", data.get("dice", 0.0))) except (json.JSONDecodeError, TypeError, ValueError): return None def _has_nan(text: str) -> bool: return "nan" in text.lower() or "NaN" in text def 
_classify_failure(stderr: str) -> str: lower = stderr.lower() if "out of memory" in lower or "outofmemoryerror" in lower: return "CUDA OOM" if "nan" in lower: return "NaN in loss" return "non-zero exit code" def _write_summary( working_dir: Path, results: list[ExperimentResult], model_family: str, best_dice: float | None, ) -> None: """Write ``run_summary.md`` to ``working_dir``.""" lines = [ f"# Autoresearch Run Summary: {model_family}", "", f"Total experiments: {len(results)}", ( f"Best val_dice: {best_dice:.4f}" if best_dice is not None else "Best val_dice: N/A" ), "", "## Experiment Log", "", "| run_id | val_dice | outcome | failure_reason | elapsed_s |", "|--------|----------|---------|----------------|-----------|", ] for r in results: dice_str = f"{r.val_dice:.4f}" if r.val_dice is not None else "—" lines.append( f"| {r.run_id} | {dice_str} | {r.outcome} | " f"{r.failure_reason or '—'} | {r.elapsed_seconds:.1f} |" ) (working_dir / "run_summary.md").write_text("\n".join(lines) + "\n") def commit_best_model( best_model_path: str | Path, best_config_path: str | Path, trained_models_path: str | Path, model_family: str, val_dice: float, source_run_id: str = "", ) -> dict[str, Any]: """Version the best model with DataLad and push to OSF. Parameters ---------- best_model_path : path Path to the ``best_model.pth`` file. best_config_path : path Path to the ``best_config.json`` file. trained_models_path : path Root of the DataLad-managed ``trained_models`` dataset. model_family : str Model family name (used as subdirectory). val_dice : float Validation Dice score of the best model. source_run_id : str Run ID string for traceability. Returns ------- dict ``ModelVersion`` with ``path``, ``datalad_commit``, and metadata. """ import datetime import torch try: import datalad.api as dl # type: ignore[import-untyped] except ImportError as exc: raise ImportError( "datalad is required for model versioning. 
" "Install it with: pip install nobrainer[versioning]" ) from exc date_str = datetime.date.today().isoformat() dest = ( Path(trained_models_path) / "neuronets" / "autoresearch" / model_family / date_str ) dest.mkdir(parents=True, exist_ok=True) shutil.copy2(best_model_path, dest / "model.pth") shutil.copy2(best_config_path, dest / "config.json") # Generate model card import platform import monai import pyro card_lines = [ f"# Model Card: {model_family}", "", "## Architecture", f"- Model family: {model_family}", "- Framework: PyTorch", "", "## Performance", f"- val_dice: {val_dice:.4f}", f"- source_run_id: {source_run_id}", "", "## Environment", f"- Python: {platform.python_version()}", f"- PyTorch: {torch.__version__}", f"- MONAI: {monai.__version__}", f"- Pyro-ppl: {pyro.__version__}", f"- Date: {date_str}", ] (dest / "model_card.md").write_text("\n".join(card_lines) + "\n") commit_msg = ( f"autoresearch: add {model_family} model ({date_str}) val_dice={val_dice:.4f}" ) dl.save(dataset=str(trained_models_path), message=commit_msg) try: dl.push(dataset=str(trained_models_path), to="osf") osf_url = "osf://" except Exception: osf_url = None return { "path": str(dest), "datalad_commit": commit_msg, "val_dice": val_dice, "model_family": model_family, "date": date_str, "osf_url": osf_url, } ================================================ FILE: nobrainer/research/templates/.gitkeep ================================================ ================================================ FILE: nobrainer/research/templates/prepare.py ================================================ """Standard data preparation script for autoresearch. Usage ----- python prepare.py --data-dir /path/to/nifti --val-fraction 0.2 Writes ``data_manifest.json`` in the current directory listing train/val split paths. 
""" from __future__ import annotations import json from pathlib import Path import random import click @click.command() @click.option( "--data-dir", required=True, type=click.Path(exists=True), help="Directory containing NIfTI files (*.nii or *.nii.gz).", ) @click.option( "--val-fraction", default=0.2, type=float, show_default=True, help="Fraction of data for validation.", ) @click.option( "--seed", default=42, type=int, show_default=True, help="Random seed for train/val split.", ) @click.option( "--output", default="data_manifest.json", show_default=True, help="Output manifest filename.", ) def prepare(*, data_dir: str, val_fraction: float, seed: int, output: str) -> None: """Validate NIfTI dataset and write train/val split manifest.""" data_path = Path(data_dir) niftis = sorted(list(data_path.glob("*.nii")) + list(data_path.glob("*.nii.gz"))) if not niftis: raise click.ClickException(f"No NIfTI files found in {data_dir}") random.seed(seed) shuffled = list(niftis) random.shuffle(shuffled) n_val = max(1, int(len(shuffled) * val_fraction)) val_paths = shuffled[:n_val] train_paths = shuffled[n_val:] manifest = { "data_dir": str(data_path.resolve()), "n_total": len(shuffled), "n_train": len(train_paths), "n_val": len(val_paths), "train": [str(p) for p in train_paths], "val": [str(p) for p in val_paths], } output_path = Path(output) output_path.write_text(json.dumps(manifest, indent=2)) click.echo( f"Manifest written to {output_path}: " f"{len(train_paths)} train, {len(val_paths)} val" ) if __name__ == "__main__": prepare() ================================================ FILE: nobrainer/research/templates/train_bayesian_vnet.py ================================================ """Bayesian VNet training script for autoresearch. The autoresearch loop patches the ``# CONFIG:`` comment line below to update hyperparameters between experiments. On completion, this script writes ``val_dice.json`` in the working directory. 
Usage ----- python train_bayesian_vnet.py """ # CONFIG: {"learning_rate": 1e-4, "batch_size": 4, "n_epochs": 20, "kl_weight": 1e-4, "dropout_rate": 0.0} # noqa: E501 from __future__ import annotations import json from pathlib import Path from monai.metrics import DiceMetric from monai.utils import set_determinism import torch import torch.optim as optim from nobrainer.dataset import get_dataset from nobrainer.losses import dice as dice_loss_fn from nobrainer.losses import elbo from nobrainer.models.bayesian import BayesianVNet from nobrainer.training import get_device def main() -> None: # ------------------------------------------------------------------ # # Load config from script comment (patched by autoresearch loop) # # ------------------------------------------------------------------ # script_text = Path(__file__).read_text() config: dict = {} for line in script_text.splitlines(): if line.strip().startswith("# CONFIG:"): config = json.loads(line.split("# CONFIG:", 1)[1].strip()) break lr: float = config.get("learning_rate", 1e-4) batch_size: int = int(config.get("batch_size", 4)) n_epochs: int = int(config.get("n_epochs", 20)) kl_weight: float = config.get("kl_weight", 1e-4) set_determinism(seed=42) device = get_device() # ------------------------------------------------------------------ # # Data loading # # ------------------------------------------------------------------ # manifest_path = Path("data_manifest.json") if not manifest_path.exists(): raise FileNotFoundError("data_manifest.json not found. 
Run prepare.py first.") manifest = json.loads(manifest_path.read_text()) train_images = manifest["train"] val_images = manifest["val"] label_suffix = "_label" # adjust per dataset convention train_labels = [ p.replace(".nii.gz", f"{label_suffix}.nii.gz") for p in train_images ] val_labels = [p.replace(".nii.gz", f"{label_suffix}.nii.gz") for p in val_images] train_loader = get_dataset( image_paths=train_images, label_paths=train_labels, batch_size=batch_size, augment=True, num_workers=0, cache_rate=0.0, ) val_loader = get_dataset( image_paths=val_images, label_paths=val_labels, batch_size=1, num_workers=0, cache_rate=0.0, ) # ------------------------------------------------------------------ # # Model, optimiser, metrics # # ------------------------------------------------------------------ # model = BayesianVNet(n_classes=2, in_channels=1, kl_weight=kl_weight).to(device) optimizer = optim.Adam(model.parameters(), lr=lr) recon_loss_fn = dice_loss_fn(softmax=True) dice_metric = DiceMetric(include_background=False, reduction="mean") # ------------------------------------------------------------------ # # Training loop # # ------------------------------------------------------------------ # import pyro for epoch in range(n_epochs): model.train() for batch in train_loader: imgs = batch["image"].to(device) labels = batch["label"].to(device).long() optimizer.zero_grad() with pyro.poutine.trace(): preds = model(imgs) labels_onehot = torch.zeros_like(preds) labels_onehot.scatter_(1, labels, 1.0) recon = recon_loss_fn(preds, labels_onehot) loss = elbo(model, kl_weight, recon) loss.backward() optimizer.step() # ------------------------------------------------------------------ # # Validation # # ------------------------------------------------------------------ # model.eval() with torch.no_grad(): for batch in val_loader: imgs = batch["image"].to(device) labels = batch["label"].to(device).long() with pyro.poutine.trace(): preds = model(imgs) preds_bin = torch.argmax(preds, 
dim=1, keepdim=True) dice_metric(preds_bin, labels) val_dice = dice_metric.aggregate().item() dice_metric.reset() # ------------------------------------------------------------------ # # Write val_dice.json # # ------------------------------------------------------------------ # Path("val_dice.json").write_text(json.dumps({"val_dice": val_dice})) print(f"val_dice: {val_dice:.4f}") if __name__ == "__main__": main() ================================================ FILE: nobrainer/slurm.py ================================================ """SLURM utilities for preemptible training with checkpoint/resume. Provides signal handling for SLURM preemption and checkpoint persistence so training jobs can be interrupted and resumed automatically via ``--requeue``. Usage:: from nobrainer.slurm import ( SlurmPreemptionHandler, save_checkpoint, load_checkpoint, ) handler = SlurmPreemptionHandler() start_epoch, metrics = load_checkpoint(ckpt_dir, model, optimizer) for epoch in range(start_epoch, total_epochs): train_one_epoch(...) save_checkpoint(ckpt_dir, model, optimizer, epoch, metrics) if handler.preempted: break # job will be requeued by SLURM """ from __future__ import annotations import json import logging import os from pathlib import Path import signal from typing import Any import torch logger = logging.getLogger(__name__) class SlurmPreemptionHandler: """Handle SLURM preemption signals for graceful checkpoint-and-exit. SLURM sends a configurable signal (default SIGUSR1 via ``--signal``) before killing a preempted job. This handler sets a flag so the training loop can checkpoint and exit cleanly. The ``--requeue`` sbatch flag then re-submits the job. On non-SLURM systems (no ``SLURM_JOB_ID`` environment variable), the handler is still safe to create but will never fire. Parameters ---------- sig : signal.Signals Signal to catch (default ``SIGUSR1``). 
""" def __init__(self, sig: signal.Signals = signal.SIGUSR1) -> None: self.preempted = False self._sig = sig try: signal.signal(sig, self._handle) logger.info("SLURM preemption handler registered (signal=%s)", sig.name) except (OSError, ValueError): # Signal registration can fail in non-main threads or on Windows logger.debug("Could not register signal handler for %s", sig) def _handle(self, signum: int, frame: Any) -> None: logger.warning( "Received preemption signal %d — will checkpoint and exit", signum ) self.preempted = True @staticmethod def is_slurm_job() -> bool: """Return True if running inside a SLURM job.""" return "SLURM_JOB_ID" in os.environ @staticmethod def slurm_info() -> dict[str, str]: """Return a dict of useful SLURM environment variables.""" keys = [ "SLURM_JOB_ID", "SLURM_JOB_NAME", "SLURM_JOB_PARTITION", "SLURM_NODELIST", "SLURM_NTASKS", "SLURM_GPUS_ON_NODE", "SLURM_RESTART_COUNT", ] return {k: os.environ[k] for k in keys if k in os.environ} def save_checkpoint( checkpoint_dir: str | Path, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int, metrics: dict[str, Any] | None = None, ) -> Path: """Save a resumable training checkpoint. Writes ``checkpoint.pt`` with model weights, optimizer state, epoch, and metrics. Also writes ``checkpoint_meta.json`` for inspection. Parameters ---------- checkpoint_dir : str or Path Directory for checkpoint files. model : torch.nn.Module Model to checkpoint. optimizer : torch.optim.Optimizer Optimizer state to persist. epoch : int Completed epoch number (0-indexed). metrics : dict, optional Accumulated training metrics. Returns ------- Path Path to the written ``checkpoint.pt``. 
""" checkpoint_dir = Path(checkpoint_dir) checkpoint_dir.mkdir(parents=True, exist_ok=True) ckpt_path = checkpoint_dir / "checkpoint.pt" torch.save( { "epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "metrics": metrics or {}, }, ckpt_path, ) meta = { "epoch": epoch, "best_loss": (metrics or {}).get("best_loss"), "train_losses": (metrics or {}).get("train_losses", [])[-3:], } with open(checkpoint_dir / "checkpoint_meta.json", "w") as f: json.dump(meta, f, indent=2, default=str) logger.info("Checkpoint saved: epoch %d → %s", epoch, ckpt_path) return ckpt_path def load_checkpoint( checkpoint_dir: str | Path, model: torch.nn.Module, optimizer: torch.optim.Optimizer | None = None, ) -> tuple[int, dict[str, Any]]: """Load a training checkpoint and return ``(start_epoch, metrics)``. Parameters ---------- checkpoint_dir : str or Path Directory containing ``checkpoint.pt``. model : torch.nn.Module Model to load weights into. optimizer : torch.optim.Optimizer or None Optimizer to restore. If None, only model is loaded. Returns ------- start_epoch : int Next epoch to train (checkpoint epoch + 1). metrics : dict Accumulated metrics from previous training. 
""" ckpt_path = Path(checkpoint_dir) / "checkpoint.pt" if not ckpt_path.exists(): logger.info("No checkpoint at %s — starting from scratch", ckpt_path) return 0, {} ckpt = torch.load(ckpt_path, weights_only=False) model.load_state_dict(ckpt["model_state_dict"]) if optimizer is not None and "optimizer_state_dict" in ckpt: optimizer.load_state_dict(ckpt["optimizer_state_dict"]) start_epoch = ckpt["epoch"] + 1 metrics = ckpt.get("metrics", {}) logger.info( "Resumed from checkpoint: epoch %d, best_loss=%.6f", ckpt["epoch"], metrics.get("best_loss", float("inf")), ) return start_epoch, metrics ================================================ FILE: nobrainer/sr-tests/__init__.py ================================================ ================================================ FILE: nobrainer/sr-tests/conftest.py ================================================ """Shared fixtures for somewhat-realistic tests.""" import pytest from nobrainer.io import read_csv from nobrainer.utils import get_data @pytest.fixture(scope="session") def sample_data(): """Download sample brain data once per test session.""" csv_path = get_data() return read_csv(csv_path) @pytest.fixture(scope="session") def train_eval_split(sample_data): """Split into 9 train + 1 eval.""" return sample_data[:9], sample_data[9] ================================================ FILE: nobrainer/sr-tests/test_bayesian_uncertainty.py ================================================ """Tests for Bayesian segmentation with uncertainty quantification. These tests train a BayesianVNet and run MC prediction on real brain data, which takes 4+ minutes on CPU. They are marked ``@pytest.mark.gpu`` so they only run on the EC2 GPU runner (where they take <30s). The same functionality is also covered by ``test_kwyk_smoke.py``. 
""" import nibabel as nib import numpy as np import pytest pyro = pytest.importorskip("pyro") # noqa: F841 from nobrainer.processing import Dataset, Segmentation # noqa: E402 @pytest.mark.gpu class TestBayesianUncertainty: """Test Bayesian model produces uncertainty estimates.""" def test_bayesian_predict_returns_tuple(self, train_eval_split, tmp_path): """Bayesian predict with n_samples returns (label, variance, entropy).""" train_data, eval_pair = train_eval_split eval_img_path = eval_pair[0] ds = ( Dataset.from_files( train_data, block_shape=(16, 16, 16), n_classes=2, ) .batch(2) .binarize() ) seg = Segmentation("bayesian_vnet") seg.fit(ds, epochs=2) result = seg.predict(eval_img_path, block_shape=(16, 16, 16), n_samples=3) # Should return a tuple of 3 NIfTI images assert isinstance(result, tuple) assert len(result) == 3 label, variance, entropy = result assert isinstance(label, nib.Nifti1Image) assert isinstance(variance, nib.Nifti1Image) assert isinstance(entropy, nib.Nifti1Image) def test_variance_nonzero(self, train_eval_split, tmp_path): """Bayesian model variance should be non-zero.""" train_data, eval_pair = train_eval_split eval_img_path = eval_pair[0] ds = ( Dataset.from_files( train_data, block_shape=(16, 16, 16), n_classes=2, ) .batch(2) .binarize() ) seg = Segmentation("bayesian_vnet") seg.fit(ds, epochs=2) _, variance, _ = seg.predict( eval_img_path, block_shape=(16, 16, 16), n_samples=3 ) var_data = np.asarray(variance.dataobj) assert np.any(var_data > 0), "Variance should be non-zero" ================================================ FILE: nobrainer/sr-tests/test_brain_generation.py ================================================ """Tests for brain generation with Progressive GAN.""" import nibabel as nib import numpy as np import pytest import torch from torch.utils.data import DataLoader, TensorDataset pl = pytest.importorskip("pytorch_lightning") # noqa: F841 from nobrainer.processing import Generation # noqa: E402 class TestBrainGeneration: 
"""Test generative model training and image generation.""" def test_generate_returns_nifti_images(self, sample_data): """Generation.fit().generate(2) returns 2 NIfTI images.""" from scipy.ndimage import zoom # Downsample real volumes to 4^3 (GAN needs small, uniform volumes) volumes = [] for img_path, _ in sample_data[:4]: vol = np.asarray(nib.load(img_path).dataobj, dtype=np.float32) vmin, vmax = vol.min(), vol.max() if vmax > vmin: vol = (vol - vmin) / (vmax - vmin) factors = [4 / s for s in vol.shape[:3]] volumes.append(zoom(vol, factors, order=1)) imgs = torch.from_numpy(np.stack(volumes)[:, None]) # (N, 1, 4, 4, 4) loader = DataLoader(TensorDataset(imgs), batch_size=2, shuffle=True) gen = Generation( "progressivegan", model_args={ "latent_size": 16, "fmap_base": 16, "fmap_max": 16, "resolution_schedule": [4], "steps_per_phase": 100, }, ) gen.fit(loader, epochs=50) images = gen.generate(n_images=2) assert len(images) == 2 for img in images: assert isinstance(img, nib.Nifti1Image) assert len(img.shape) >= 3 ================================================ FILE: nobrainer/sr-tests/test_croissant_metadata.py ================================================ """Tests for Croissant-ML metadata generation.""" import json from pathlib import Path from nobrainer.processing import Dataset, Segmentation class TestCroissantMetadata: """Test Croissant-ML provenance in saved models and datasets.""" def test_segmentation_save_croissant_fields(self, train_eval_split, tmp_path): """Segmentation.save() produces croissant.json with provenance fields.""" train_data, _ = train_eval_split ds = ( Dataset.from_files( train_data, block_shape=(16, 16, 16), n_classes=2, ) .batch(2) .binarize() ) seg = Segmentation( "unet", model_args={"channels": (4, 8), "strides": (2,)}, ) seg.fit(ds, epochs=2) save_dir = tmp_path / "croissant_model" seg.save(save_dir) croissant_path = save_dir / "croissant.json" assert croissant_path.exists() meta = json.loads(croissant_path.read_text()) # Should have 
Croissant-ML context or provenance has_context = "@context" in meta has_provenance = "nobrainer:provenance" in meta assert ( has_context or has_provenance ), "croissant.json must have @context or nobrainer:provenance" # Check provenance fields if present if has_provenance: prov = meta["nobrainer:provenance"] assert "model_architecture" in prov assert "n_classes" in prov assert prov["model_architecture"] == "unet" assert prov["n_classes"] == 2 def test_dataset_to_croissant(self, train_eval_split, tmp_path): """Dataset.to_croissant() writes a JSON file containing ``@context`` or ``name``.""" train_data, _ = train_eval_split ds = Dataset.from_files( train_data, block_shape=(16, 16, 16), n_classes=2, ) output_path = tmp_path / "dataset_croissant.json" result = ds.to_croissant(output_path) assert Path(result).exists() meta = json.loads(Path(result).read_text()) assert "@context" in meta or "name" in meta ================================================ FILE: nobrainer/sr-tests/test_dataset_builder.py ================================================ """Tests for the fluent Dataset builder with real brain data.""" from nobrainer.processing import Dataset class TestDatasetBuilder: """Test Dataset.from_files() fluent API produces correct outputs.""" def test_from_files_batch_binarize_augment(self, train_eval_split): """Dataset.from_files().batch(2).binarize().augment() produces correct shapes.""" train_data, _ = train_eval_split ds = ( Dataset.from_files( train_data, block_shape=(16, 16, 16), n_classes=2, ) .batch(2) .binarize() .augment() ) loader = ds.dataloader batch = next(iter(loader)) assert "image" in batch assert "label" in batch # batch_size=2, 1 channel, block_shape=(16,16,16) assert batch["image"].shape[0] == 2 assert batch["image"].shape[-3:] == (16, 16, 16) assert batch["label"].shape[0] == 2 def test_split_sizes(self, train_eval_split): """Dataset.split() sizes add up to the input size, with a non-empty eval set.""" train_data, _ = train_eval_split ds = Dataset.from_files( train_data,
block_shape=(16, 16, 16), n_classes=2, ) ds_train, ds_eval = ds.split(eval_size=0.2) total = len(train_data) assert len(ds_train.data) + len(ds_eval.data) == total assert len(ds_eval.data) >= 1 def test_streaming_mode_produces_patches(self, train_eval_split): """Dataset.streaming() produces patches via PatchDataset.""" train_data, _ = train_eval_split ds = ( Dataset.from_files( train_data, block_shape=(16, 16, 16), n_classes=2, ) .batch(2) .binarize() .streaming(patches_per_volume=2) ) loader = ds.dataloader batch = next(iter(loader)) assert "image" in batch assert batch["image"].shape[-3:] == (16, 16, 16) ================================================ FILE: nobrainer/sr-tests/test_extract_patches.py ================================================ """Tests for extract_patches() with various binarization modes.""" import nibabel as nib import numpy as np import pytest from nobrainer.processing import extract_patches class TestExtractPatches: """Test extract_patches() on real brain data.""" @pytest.fixture() def volume_and_label(self, train_eval_split): """Load first volume and label as numpy arrays.""" train_data, _ = train_eval_split img_path, lbl_path = train_data[0] vol = np.asarray(nib.load(img_path).dataobj, dtype=np.float32) lbl = np.asarray(nib.load(lbl_path).dataobj, dtype=np.float32) return vol, lbl def test_binarize_true(self, volume_and_label): """binarize=True maps any non-zero label to 1.""" vol, lbl = volume_and_label patches = extract_patches( vol, lbl, block_shape=(16, 16, 16), n_patches=5, binarize=True ) assert len(patches) == 5 for img_patch, lbl_patch in patches: assert img_patch.shape == (16, 16, 16) assert lbl_patch.shape == (16, 16, 16) # Only 0 and 1 in binarized labels unique_vals = set(np.unique(lbl_patch)) assert unique_vals <= {0.0, 1.0} def test_binarize_set(self, volume_and_label): """binarize={17, 53} selects hippocampus labels only; label patches are 0/1.""" vol, lbl = volume_and_label patches = extract_patches( vol, lbl, block_shape=(16, 16, 16),
n_patches=5, binarize={17, 53} ) for img_patch, lbl_patch in patches: assert img_patch.shape == (16, 16, 16) unique_vals = set(np.unique(lbl_patch)) assert unique_vals <= {0.0, 1.0} def test_binarize_callable(self, volume_and_label): """binarize=<callable> applies a custom function to label patches.""" vol, lbl = volume_and_label def threshold_fn(x): return (x >= 1000).astype(np.float32) patches = extract_patches( vol, lbl, block_shape=(16, 16, 16), n_patches=5, binarize=threshold_fn ) for img_patch, lbl_patch in patches: assert img_patch.shape == (16, 16, 16) unique_vals = set(np.unique(lbl_patch)) assert unique_vals <= {0.0, 1.0} def test_patch_shapes(self, volume_and_label): """Patches have the requested block_shape.""" vol, lbl = volume_and_label patches = extract_patches(vol, lbl, block_shape=(16, 16, 16), n_patches=3) assert len(patches) == 3 for img_patch, lbl_patch in patches: assert img_patch.shape == (16, 16, 16) assert lbl_patch.shape == (16, 16, 16) ================================================ FILE: nobrainer/sr-tests/test_kwyk_smoke.py ================================================ """Smoke tests for the kwyk reproduction pipeline. Tests train a tiny MeshNet and Bayesian MeshNet for at most 1 epoch each (the warm-start test uses a single batch) to verify the end-to-end pipeline works (loss is finite, prediction produces valid NIfTI output, warm-start transfers weights correctly).
""" import nibabel as nib import numpy as np import pytest import torch pyro = pytest.importorskip("pyro") from nobrainer.models import get as get_model # noqa: E402 from nobrainer.models.bayesian.warmstart import ( # noqa: E402 warmstart_bayesian_from_deterministic, ) from nobrainer.processing import Dataset, Segmentation # noqa: E402 from nobrainer.training import get_device # noqa: E402 # --------------------------------------------------------------------------- # Shared constants for tiny model # --------------------------------------------------------------------------- FILTERS = 16 BLOCK_SHAPE = (16, 16, 16) N_CLASSES = 2 BATCH_SIZE = 2 MODEL_ARGS = { "n_classes": N_CLASSES, "filters": FILTERS, "receptive_field": 37, "dropout_rate": 0.25, } def _build_dataset(sample_data): """Build a small binarized Dataset from sample_data fixture.""" # Use first 5 volumes pairs = sample_data[:5] ds = ( Dataset.from_files(pairs, block_shape=BLOCK_SHAPE, n_classes=N_CLASSES) .batch(BATCH_SIZE) .binarize() ) return ds def _plot_learning_curve(losses, output_path): """Save a simple learning curve figure.""" import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt # noqa: E402 fig, ax = plt.subplots(figsize=(6, 4)) ax.plot(range(1, len(losses) + 1), losses, "b-o", markersize=4) ax.set_xlabel("Step") ax.set_ylabel("Loss") ax.set_title("Smoke Test Learning Curve") ax.grid(True, alpha=0.3) fig.tight_layout() fig.savefig(output_path, dpi=72, bbox_inches="tight") plt.close(fig) @pytest.mark.gpu class TestKwykSmoke: """Smoke tests for the kwyk reproduction pipeline.""" def test_deterministic_meshnet_train(self, sample_data, tmp_path): """Train deterministic MeshNet for 1 epoch; assert loss is finite.""" ds = _build_dataset(sample_data) seg = Segmentation( base_model="meshnet", model_args={k: v for k, v in MODEL_ARGS.items() if k != "n_classes"}, ) # Collect losses via callback losses = [] def _on_epoch(epoch, logs, model): losses.append(logs["loss"] if isinstance(logs,
dict) else logs) seg.fit(ds, epochs=1, callbacks=[_on_epoch]) assert len(losses) >= 1, "Expected at least 1 epoch of training" for loss_val in losses: assert np.isfinite(loss_val), f"Loss is not finite: {loss_val}" # Save learning curve _plot_learning_curve(losses, tmp_path / "det_learning_curve.png") assert (tmp_path / "det_learning_curve.png").exists() def test_bayesian_warmstart_train(self, sample_data, tmp_path): """Warm-start BayesianMeshNet from deterministic; train each on a single batch.""" ds = _build_dataset(sample_data) # First train a deterministic model det_model = get_model("meshnet")(**MODEL_ARGS) device = get_device() det_model = det_model.to(device) det_model.train() ce_loss = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(det_model.parameters(), lr=1e-3) # Quick single-batch train of deterministic model loader = ds.dataloader for batch in loader: if isinstance(batch, (list, tuple)): images = batch[0].to(device) labels = batch[1].to(device) else: images = batch["image"].to(device) labels = batch["label"].to(device) if labels.ndim == images.ndim and labels.shape[1] == 1: labels = labels.squeeze(1) if labels.dtype in (torch.float32, torch.float64): labels = labels.long() optimizer.zero_grad() pred = det_model(images) loss = ce_loss(pred, labels) loss.backward() optimizer.step() break # Just one batch for speed # Build Bayesian model and warm-start (on CPU to avoid device issues) det_model_cpu = det_model.cpu() bayes_model = get_model("bayesian_meshnet")(**MODEL_ARGS) n_transferred = warmstart_bayesian_from_deterministic( bayes_model, det_model_cpu, initial_rho=-3.0 ) assert n_transferred > 0, "Expected at least 1 layer transferred" # Train Bayesian on a single batch from nobrainer.models.bayesian.utils import accumulate_kl # Pyro's param store can cache unconstrained tensors on CPU even # after .to(device). Clear and re-register to ensure device consistency.
pyro.clear_param_store() bayes_model = bayes_model.to(device) bayes_model.train() optimizer_b = torch.optim.Adam(bayes_model.parameters(), lr=1e-3) losses = [] for batch in loader: if isinstance(batch, (list, tuple)): images = batch[0].to(device) labels = batch[1].to(device) else: images = batch["image"].to(device) labels = batch["label"].to(device) if labels.ndim == images.ndim and labels.shape[1] == 1: labels = labels.squeeze(1) if labels.dtype in (torch.float32, torch.float64): labels = labels.long() optimizer_b.zero_grad() pred = bayes_model(images) loss = ce_loss(pred, labels) + accumulate_kl(bayes_model) loss.backward() optimizer_b.step() losses.append(loss.item()) break # Just one batch for speed assert len(losses) >= 1 for loss_val in losses: assert np.isfinite(loss_val), f"Bayesian loss not finite: {loss_val}" # Save learning curve _plot_learning_curve(losses, tmp_path / "bayes_learning_curve.png") assert (tmp_path / "bayes_learning_curve.png").exists() def test_predict_output(self, sample_data, tmp_path): """Predict on 1 volume; assert NIfTI output with matching shape.""" ds = _build_dataset(sample_data) seg = Segmentation( base_model="meshnet", model_args={k: v for k, v in MODEL_ARGS.items() if k != "n_classes"}, ) seg.fit(ds, epochs=1) # Predict on first volume eval_img_path = sample_data[0][0] eval_lbl_path = sample_data[0][1] result = seg.predict(eval_img_path, block_shape=BLOCK_SHAPE) # Check output is NIfTI assert isinstance( result, nib.Nifti1Image ), f"Expected Nifti1Image, got {type(result)}" # Check shape matches input spatial dims input_img = nib.load(eval_img_path) input_shape = input_img.shape[:3] result_shape = result.shape[:3] assert ( result_shape == input_shape ), f"Shape mismatch: input={input_shape}, output={result_shape}" # Compute Dice for informational purposes — a 1-epoch model # may produce all-zero predictions, so we don't require Dice > 0.
gt_arr = np.asarray(nib.load(eval_lbl_path).dataobj, dtype=np.float32) gt_binary = (gt_arr > 0).astype(np.float32) pred_arr = np.asarray(result.dataobj, dtype=np.float32) pred_binary = (pred_arr > 0).astype(np.float32) intersection = np.logical_and(pred_binary, gt_binary).sum() total = pred_binary.sum() + gt_binary.sum() if total > 0: dice = float(2.0 * intersection / total) else: dice = 1.0 # Dice >= 0 is always true; we just verify the computation doesn't crash. # With more epochs, Dice should improve — this is a smoke test only. assert dice >= 0, f"Expected Dice >= 0, got {dice}" # Save learning curve figure _plot_learning_curve([0.5], tmp_path / "predict_learning_curve.png") assert (tmp_path / "predict_learning_curve.png").exists() ================================================ FILE: nobrainer/sr-tests/test_raw_pytorch_api.py ================================================ """Tests for the raw PyTorch API without the estimator layer.""" import nibabel as nib import torch import nobrainer.models from nobrainer.prediction import predict from nobrainer.training import fit as training_fit class TestRawPyTorchAPI: """Test using raw nobrainer modules directly (no estimator).""" def test_raw_train_predict_cycle(self, train_eval_split, tmp_path): """Train with nobrainer.training.fit, predict with nobrainer.prediction.predict.""" train_data, eval_pair = train_eval_split eval_img_path = eval_pair[0] # Build model directly model_factory = nobrainer.models.get("unet") model = model_factory(n_classes=2, channels=(4, 8), strides=(2,)) # Build dataset directly from nobrainer.dataset import get_dataset image_paths = [pair[0] for pair in train_data] label_paths = [pair[1] for pair in train_data] loader = get_dataset( image_paths=image_paths, label_paths=label_paths, block_shape=(16, 16, 16), batch_size=2, binarize_labels=True, ) # Train optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) criterion = torch.nn.CrossEntropyLoss() result = training_fit( model=model,
loader=loader, criterion=criterion, optimizer=optimizer, max_epochs=2, gpus=0, ) assert "history" in result assert len(result["history"]) == 2 # Predict model.eval() prediction = predict( inputs=eval_img_path, model=model, block_shape=(16, 16, 16), ) assert isinstance(prediction, nib.Nifti1Image) assert len(prediction.shape) >= 3 ================================================ FILE: nobrainer/sr-tests/test_segmentation_estimator.py ================================================ """Tests for the Segmentation estimator with real brain data.""" import json import nibabel as nib from nobrainer.processing import Dataset, Segmentation class TestSegmentationEstimator: """Test Segmentation estimator fit/predict/save/load cycle.""" def test_fit_predict_returns_nifti(self, train_eval_split, tmp_path): """Segmentation.fit().predict() returns a NIfTI image.""" train_data, eval_pair = train_eval_split eval_img_path = eval_pair[0] ds = ( Dataset.from_files( train_data, block_shape=(16, 16, 16), n_classes=2, ) .batch(2) .binarize() ) seg = Segmentation( "unet", model_args={"channels": (4, 8), "strides": (2,)}, ) seg.fit(ds, epochs=2) result = seg.predict(eval_img_path, block_shape=(16, 16, 16)) assert isinstance(result, nib.Nifti1Image) assert len(result.shape) >= 3 def test_save_creates_croissant(self, train_eval_split, tmp_path): """Segmentation.save() creates model.pth and croissant.json.""" train_data, _ = train_eval_split ds = ( Dataset.from_files( train_data, block_shape=(16, 16, 16), n_classes=2, ) .batch(2) .binarize() ) seg = Segmentation( "unet", model_args={"channels": (4, 8), "strides": (2,)}, ) seg.fit(ds, epochs=2) save_dir = tmp_path / "saved_model" seg.save(save_dir) assert (save_dir / "model.pth").exists() assert (save_dir / "croissant.json").exists() meta = json.loads((save_dir / "croissant.json").read_text()) assert "@context" in meta or "nobrainer:provenance" in meta def test_load_roundtrip(self, train_eval_split, tmp_path): """Segmentation.save() then
Segmentation.load() restores a model that can still predict a NIfTI image.""" train_data, eval_pair = train_eval_split eval_img_path = eval_pair[0] ds = ( Dataset.from_files( train_data, block_shape=(16, 16, 16), n_classes=2, ) .batch(2) .binarize() ) seg = Segmentation( "unet", model_args={"channels": (4, 8), "strides": (2,)}, ) seg.fit(ds, epochs=2) save_dir = tmp_path / "roundtrip_model" seg.save(save_dir) loaded = Segmentation.load(save_dir) result = loaded.predict(eval_img_path, block_shape=(16, 16, 16)) assert isinstance(result, nib.Nifti1Image) ================================================ FILE: nobrainer/sr-tests/test_synthseg_brain.py ================================================ """SR-test: SynthSeg generation from real aparc+aseg label maps. Tests that the enhanced SynthSeg generator produces plausible (correctly typed, non-constant) synthetic images from actual FreeSurfer parcellation data. """ from __future__ import annotations import numpy as np import torch class TestSynthSegBrain: """SynthSeg with real brain data.""" def test_generate_from_sample_data(self, sample_data): """Generate synthetic image from real aparc+aseg label map.""" from nobrainer.augmentation.synthseg import SynthSegGenerator # sample_data is list of (image, label) tuples label_paths = [p[1] for p in sample_data[:2]] gen = SynthSegGenerator( label_paths, n_samples_per_map=2, elastic_std=2.0, # mild deformation for speed rotation_range=10.0, randomize_resolution=False, # skip for speed ) sample = gen[0] assert sample["image"].shape[0] == 1 # channel dim assert sample["label"].shape[0] == 1 assert sample["image"].dtype == torch.float32 assert sample["label"].dtype == torch.int64 # Image should have non-zero values in brain region img = sample["image"][0].numpy() lbl = sample["label"][0].numpy() brain_mask = lbl > 0 assert brain_mask.sum() > 0 assert img[brain_mask].std() > 0 # not constant def test_two_samples_differ(self, sample_data): """Two samples from same label map should differ.""" from nobrainer.augmentation.synthseg import
SynthSegGenerator label_paths = [p[1] for p in sample_data[:1]] gen = SynthSegGenerator( label_paths, n_samples_per_map=2, elastic_std=0, rotation_range=0, flipping=False, randomize_resolution=False, ) s1 = gen[0]["image"] s2 = gen[1]["image"] assert not torch.allclose(s1, s2) def test_label_structure_preserved(self, sample_data): """Spatial augmentation should preserve label topology.""" from nobrainer.augmentation.synthseg import SynthSegGenerator label_paths = [p[1] for p in sample_data[:1]] gen = SynthSegGenerator( label_paths, n_samples_per_map=1, elastic_std=2.0, rotation_range=5.0, ) sample = gen[0] lbl = sample["label"][0].numpy() # Should still have brain structure (not all zeros or all one label) unique = np.unique(lbl) assert len(unique) > 2 # at least background + 2 regions ================================================ FILE: nobrainer/sr-tests/test_zarr_conversion.py ================================================ """Tests for NIfTI-to-Zarr and Zarr-to-NIfTI conversion.""" from pathlib import Path import nibabel as nib import numpy as np import pytest zarr = pytest.importorskip("zarr") # noqa: F841 from nobrainer.io import nifti_to_zarr, zarr_to_nifti # noqa: E402 def _mgz_to_nifti(mgz_path: str, output_dir: Path) -> Path: """Convert .mgz to .nii.gz (niizarr doesn't support MGH).""" img = nib.load(mgz_path) out = output_dir / (Path(mgz_path).stem + ".nii.gz") nib.save(nib.Nifti1Image(np.asarray(img.dataobj), img.affine), str(out)) return out class TestZarrConversion: """Test Zarr round-trip conversion on real brain data.""" def test_nifti_to_zarr(self, train_eval_split, tmp_path): """nifti_to_zarr() creates a valid Zarr store from a real volume.""" train_data, _ = train_eval_split mgz_path = train_data[0][0] nii_path = _mgz_to_nifti(mgz_path, tmp_path) zarr_path = tmp_path / "brain.zarr" result = nifti_to_zarr(nii_path, zarr_path, chunk_shape=(16, 16, 16), levels=1) assert Path(result).exists() import zarr as zarr_mod store = 
zarr_mod.open_group(str(zarr_path), mode="r") assert "0" in store arr = np.asarray(store["0"]) assert arr.ndim == 3 def test_zarr_to_nifti_roundtrip(self, train_eval_split, tmp_path): """zarr_to_nifti() round-trips back to NIfTI with matching shape.""" train_data, _ = train_eval_split mgz_path = train_data[0][0] nii_path = _mgz_to_nifti(mgz_path, tmp_path) zarr_path = tmp_path / "roundtrip.zarr" nifti_to_zarr(nii_path, zarr_path, chunk_shape=(16, 16, 16), levels=1) roundtrip_path = tmp_path / "roundtrip.nii.gz" zarr_to_nifti(zarr_path, roundtrip_path) original = nib.load(str(nii_path)) roundtrip = nib.load(str(roundtrip_path)) assert original.shape == roundtrip.shape # Value range should be preserved (exact match may differ due to # niizarr orientation transforms) orig_data = np.asarray(original.dataobj, dtype=np.float32) rt_data = np.asarray(roundtrip.dataobj, dtype=np.float32) assert abs(orig_data.mean() - rt_data.mean()) < orig_data.std() * 0.5 def test_multi_resolution_pyramid(self, train_eval_split, tmp_path): """nifti_to_zarr(levels=3) creates a multi-resolution pyramid.""" train_data, _ = train_eval_split mgz_path = train_data[0][0] nii_path = _mgz_to_nifti(mgz_path, tmp_path) zarr_path = tmp_path / "pyramid.zarr" nifti_to_zarr(nii_path, zarr_path, chunk_shape=(16, 16, 16), levels=3) import zarr as zarr_mod store = zarr_mod.open_group(str(zarr_path), mode="r") # Should have levels 0, 1, 2 assert "0" in store assert "1" in store assert "2" in store shape_0 = np.asarray(store["0"]).shape shape_1 = np.asarray(store["1"]).shape shape_2 = np.asarray(store["2"]).shape # Each level should be roughly half the previous for dim in range(3): assert shape_1[dim] <= shape_0[dim] assert shape_2[dim] <= shape_1[dim] ================================================ FILE: nobrainer/sr-tests/test_zarr_pipeline.py ================================================ """SR-test: end-to-end Zarr pipeline with real brain data. 
Converts sample brain data to Zarr, creates partition, builds Dataset.from_zarr(), and verifies the DataLoader yields correct patches. """ from __future__ import annotations import numpy as np from nobrainer.processing import Dataset class TestZarrPipeline: """End-to-end Zarr store → partition → Dataset → DataLoader.""" def test_zarr_store_from_sample_data(self, sample_data, tmp_path): """Convert sample data to Zarr, create partition, load via Dataset.""" from nobrainer.datasets.zarr_store import ( create_partition, create_zarr_store, store_info, ) # Use first 5 subjects pairs = sample_data[:5] # Create Zarr store (auto-conform since shapes may differ) store_path = create_zarr_store( pairs, tmp_path / "brain.zarr", conform=True, ) # Verify store metadata info = store_info(store_path) assert info["n_subjects"] == 5 assert info["layout"] == "stacked" assert info["conformed"] is True # Create partition part_path = create_partition(store_path, ratios=(60, 20, 20), seed=42) # Build Dataset from Zarr with partition ds = Dataset.from_zarr( store_path, block_shape=(16, 16, 16), n_classes=2, partition="train", partition_path=part_path, ) # Verify data list is filtered assert len(ds.data) == 3 # 60% of 5 = 3 # Verify Zarr metadata in data entries assert "_zarr_index" in ds.data[0] assert "_subject_id" in ds.data[0] def test_zarr_store_roundtrip(self, sample_data, tmp_path): """Verify Zarr store preserves data fidelity.""" import zarr from nobrainer.datasets.zarr_store import create_zarr_store pairs = sample_data[:2] store_path = create_zarr_store(pairs, tmp_path / "brain.zarr", conform=True) store = zarr.open_group(str(store_path), mode="r") assert store["images"].shape[0] == 2 assert store["labels"].shape[0] == 2 # Images should be float32, labels int32 assert store["images"].dtype == np.float32 assert store["labels"].dtype == np.int32 ================================================ FILE: nobrainer/tests/__init__.py ================================================ 
================================================ FILE: nobrainer/tests/contract/__init__.py ================================================ ================================================ FILE: nobrainer/tests/contract/test_cli.py ================================================ """CLI contract tests for nobrainer commands. Verifies that all CLI commands advertised in contracts/nobrainer-pytorch-api.md are present, have the expected options, and exit with code 0 on --help. """ from __future__ import annotations import subprocess import sys def _help(cmd: list[str]) -> str: """Run `nobrainer --help` and return stdout.""" result = subprocess.run( [sys.executable, "-m", "nobrainer.cli.main"] + cmd + ["--help"], capture_output=True, text=True, ) assert ( result.returncode == 0 ), f"'{' '.join(cmd)} --help' exited {result.returncode}:\n{result.stderr}" return result.stdout class TestPredictCommand: def test_predict_help_exits_zero(self): _help(["predict"]) def test_predict_has_model_option(self): out = _help(["predict"]) assert "--model" in out or "-m" in out def test_predict_has_model_type_option(self): out = _help(["predict"]) assert "--model-type" in out def test_predict_has_n_classes_option(self): out = _help(["predict"]) assert "--n-classes" in out def test_predict_has_device_option(self): out = _help(["predict"]) assert "--device" in out def test_predict_has_n_samples_option(self): out = _help(["predict"]) assert "--n-samples" in out class TestGenerateCommand: def test_generate_help_exits_zero(self): _help(["generate"]) def test_generate_has_model_option(self): out = _help(["generate"]) assert "--model" in out or "-m" in out def test_generate_has_model_type_option(self): out = _help(["generate"]) assert "--model-type" in out def test_generate_has_n_samples_option(self): out = _help(["generate"]) assert "--n-samples" in out def test_generate_has_latent_size_option(self): out = _help(["generate"]) assert "--latent-size" in out class TestConvertTfrecordsCommand: 
def test_convert_tfrecords_help_exits_zero(self): _help(["convert-tfrecords"]) def test_convert_tfrecords_has_input_option(self): out = _help(["convert-tfrecords"]) assert "--input" in out or "-i" in out def test_convert_tfrecords_has_output_dir_option(self): out = _help(["convert-tfrecords"]) assert "--output-dir" in out class TestResearchCommand: def test_research_help_exits_zero(self): _help(["research"]) def test_research_has_working_dir_option(self): out = _help(["research"]) assert "--working-dir" in out def test_research_has_max_experiments_option(self): out = _help(["research"]) assert "--max-experiments" in out def test_research_has_budget_hours_option(self): out = _help(["research"]) assert "--budget-hours" in out class TestCommitCommand: def test_commit_help_exits_zero(self): _help(["commit"]) def test_commit_has_model_path_option(self): out = _help(["commit"]) assert "--model-path" in out def test_commit_has_config_path_option(self): out = _help(["commit"]) assert "--config-path" in out def test_commit_has_val_dice_option(self): out = _help(["commit"]) assert "--val-dice" in out class TestInfoCommand: def test_info_help_exits_zero(self): _help(["info"]) ================================================ FILE: nobrainer/tests/gpu/__init__.py ================================================ ================================================ FILE: nobrainer/tests/gpu/test_bayesian_e2e.py ================================================ """GPU end-to-end test: Bayesian VNet with uncertainty quantification. T045 — US2 acceptance scenario: predict_with_uncertainty() produces label, variance, and entropy maps. Variance and entropy are non-zero. Bayesian model trained via overfit on synthetic sphere data achieves Dice >= 0.90 (lower than deterministic due to stochastic inference). 
""" from __future__ import annotations import numpy as np import pytest import torch import torch.nn as nn from nobrainer.models.bayesian import BayesianVNet from nobrainer.prediction import predict_with_uncertainty def _make_sphere_volume(shape=(64, 64, 64), radius=20): """Create a synthetic volume with a centered sphere as the label.""" vol = np.random.rand(*shape).astype(np.float32) * 0.3 label = np.zeros(shape, dtype=np.float32) center = np.array(shape) / 2 coords = np.mgrid[: shape[0], : shape[1], : shape[2]] dist = np.sqrt(sum((c - ctr) ** 2 for c, ctr in zip(coords, center))) mask = dist < radius label[mask] = 1.0 vol[mask] += 0.7 return vol, label @pytest.mark.gpu class TestBayesianEndToEnd: def test_bayesian_vnet_overfit_with_uncertainty(self): """Train 2-class BayesianVNet, run MC inference, check Dice and uncertainty.""" device = torch.device("cuda") torch.manual_seed(42) vol, label = _make_sphere_volume(shape=(64, 64, 64), radius=20) x = torch.from_numpy(vol[None, None]).to(device) label_long = torch.from_numpy(label).long().to(device) # Use n_classes=2 so softmax produces meaningful probabilities model = BayesianVNet( in_channels=1, n_classes=2, prior_type="standard_normal" ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) criterion = nn.CrossEntropyLoss() # Overfit model.train() for _ in range(200): optimizer.zero_grad() pred = model(x) loss = criterion(pred, label_long.unsqueeze(0)) loss.backward() optimizer.step() # Run MC inference with uncertainty label_img, var_img, entropy_img = predict_with_uncertainty( inputs=vol, model=model, n_samples=10, block_shape=(32, 32, 32), batch_size=8, device="cuda", ) # Check shapes assert label_img.shape == (64, 64, 64) assert var_img.shape == (64, 64, 64) assert entropy_img.shape == (64, 64, 64) # Variance and entropy should be non-zero (stochastic model) var_data = np.asarray(var_img.dataobj) entropy_data = np.asarray(entropy_img.dataobj) assert var_data.sum() > 0, "Variance map is all 
zeros" assert entropy_data.sum() > 0, "Entropy map is all zeros" # Dice check for class 1 (>= 0.90, relaxed for Bayesian stochasticity) pred_arr = np.asarray(label_img.dataobj) pred_bin = (pred_arr == 1).astype(np.float32) intersection = (pred_bin * label).sum() dice = 2 * intersection / (pred_bin.sum() + label.sum() + 1e-8) assert dice >= 0.90, f"Bayesian Dice {dice:.4f} < 0.90 threshold" ================================================ FILE: nobrainer/tests/gpu/test_gan_e2e.py ================================================ """GPU end-to-end test: ProgressiveGAN training. T054 — US3 acceptance scenario: ProgressiveGAN completes extended training on synthetic 3D volumes without NaN in losses. Generated output has correct shape and non-trivial intensity distribution. """ from __future__ import annotations import numpy as np import pytest import pytorch_lightning as pl import torch from torch.utils.data import DataLoader, TensorDataset from nobrainer.models.generative import ProgressiveGAN def _make_loader(n_samples=64, spatial=4, batch_size=4): """Create a DataLoader with enough data for extended training.""" imgs = torch.randn(n_samples, 1, spatial, spatial, spatial) return DataLoader(TensorDataset(imgs), batch_size=batch_size, shuffle=True) @pytest.mark.gpu class TestProgressiveGANEndToEnd: def test_extended_training_no_nan(self): """Train ProgressiveGAN for many steps; verify no NaN in discriminator.""" torch.manual_seed(42) loader = _make_loader(n_samples=64, spatial=4, batch_size=4) model = ProgressiveGAN( latent_size=32, fmap_base=32, fmap_max=32, resolution_schedule=[4], steps_per_phase=2000, ) trainer = pl.Trainer( max_steps=500, accelerator="gpu", devices=1, enable_checkpointing=False, logger=False, enable_progress_bar=False, ) trainer.fit(model, loader) # Verify discriminator outputs are finite after training model.eval() with torch.no_grad(): x_real = next(iter(loader))[0].to(model.device) z = torch.randn(x_real.size(0), 32, device=model.device) x_fake 
= model.generator(z) d_real = model.discriminator(x_real) d_fake = model.discriminator(x_fake) assert torch.isfinite(d_real).all(), "d_real contains NaN/Inf" assert torch.isfinite(d_fake).all(), "d_fake contains NaN/Inf" assert not torch.isnan(x_fake).any(), "Generated volumes contain NaN" def test_generated_output_shape(self): """After training, generated volumes have correct shape.""" torch.manual_seed(42) loader = _make_loader(n_samples=32, spatial=4, batch_size=4) model = ProgressiveGAN( latent_size=32, fmap_base=32, fmap_max=32, resolution_schedule=[4], steps_per_phase=500, ) trainer = pl.Trainer( max_steps=100, accelerator="gpu", devices=1, enable_checkpointing=False, logger=False, enable_progress_bar=False, ) trainer.fit(model, loader) model.eval() model.generator.current_level = 0 model.generator.alpha = 1.0 with torch.no_grad(): z = torch.randn(4, 32, device=model.device) generated = model.generator(z) # Check shape: (4, 1, 4, 4, 4) assert generated.shape == ( 4, 1, 4, 4, 4, ), f"Expected (4, 1, 4, 4, 4), got {generated.shape}" assert not np.isnan(generated.cpu().numpy()).any(), "NaN in generated" ================================================ FILE: nobrainer/tests/gpu/test_multi_gpu.py ================================================ """GPU integration test: multi-GPU training and inference. T035 — US4: requires 2+ GPUs. Tests DDP training speedup and multi-GPU predict() correctness. 
""" from __future__ import annotations import time import numpy as np import pytest import torch import torch.nn as nn from torch.utils.data import DataLoader, TensorDataset from nobrainer.models.segmentation import unet from nobrainer.prediction import predict from nobrainer.training import fit @pytest.mark.gpu @pytest.mark.skipif( torch.cuda.device_count() < 2, reason="Requires 2+ GPUs for multi-GPU tests", ) class TestMultiGPU: def test_ddp_fit_loss_decreases(self): """fit() with gpus=2 produces decreasing loss.""" torch.manual_seed(42) x = torch.randn(16, 1, 16, 16, 16) y = torch.randint(0, 2, (16, 16, 16, 16)) loader = DataLoader(TensorDataset(x, y), batch_size=4) model = nn.Sequential( nn.Conv3d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv3d(8, 2, 1) ) losses = [] def track(epoch, loss, model): losses.append(loss) result = fit( model, loader, nn.CrossEntropyLoss(), torch.optim.Adam(model.parameters(), lr=1e-2), max_epochs=10, gpus=2, callbacks=[track], ) final_loss = result["history"][-1]["loss"] assert final_loss < losses[0] def test_multi_gpu_predict_matches_single(self): """Multi-GPU predict() output matches single-GPU result.""" torch.manual_seed(42) vol = np.random.rand(32, 32, 32).astype(np.float32) model = unet(n_classes=2) # Single GPU result_single = predict( inputs=vol, model=model, block_shape=(16, 16, 16), device="cuda:0", ) # Multi GPU (auto-distributes) result_multi = predict( inputs=vol, model=model, block_shape=(16, 16, 16), device="cuda", ) single_arr = np.asarray(result_single.dataobj) multi_arr = np.asarray(result_multi.dataobj) assert np.array_equal(single_arr, multi_arr) def test_ddp_speedup(self): """2-GPU training achieves >=1.3x speedup vs 1 GPU.""" torch.manual_seed(42) x = torch.randn(32, 1, 16, 16, 16) y = torch.randint(0, 2, (32, 16, 16, 16)) loader = DataLoader(TensorDataset(x, y), batch_size=4) model = nn.Sequential( nn.Conv3d(1, 16, 3, padding=1), nn.ReLU(), nn.Conv3d(16, 2, 1) ) # Time single GPU t0 = time.time() fit( model, loader, 
nn.CrossEntropyLoss(), torch.optim.Adam(model.parameters()), max_epochs=5, gpus=1, ) single_time = time.time() - t0 # Time 2 GPUs t0 = time.time() fit( model, loader, nn.CrossEntropyLoss(), torch.optim.Adam(model.parameters()), max_epochs=5, gpus=2, ) multi_time = time.time() - t0 speedup = single_time / multi_time print( f"Speedup: {speedup:.2f}x (single={single_time:.1f}s, multi={multi_time:.1f}s)" ) assert speedup >= 1.3, f"Speedup {speedup:.2f}x < 1.3x threshold" ================================================ FILE: nobrainer/tests/gpu/test_predict_e2e.py ================================================ """GPU end-to-end test: train a UNet on synthetic data, then verify predict() produces high Dice on the same data (overfitting test). T031 — US1 acceptance scenario 2: Dice >= 0.95 on a known volume. Since we don't ship reference weights in the repo, this test creates a synthetic brain-like volume (sphere label), trains a UNet to overfit it, then runs predict() and checks the Dice score. 
""" from __future__ import annotations import nibabel as nib import numpy as np import pytest import torch import torch.nn as nn from nobrainer.models.segmentation import unet from nobrainer.prediction import predict def _make_sphere_volume(shape=(64, 64, 64), radius=20): """Create a synthetic volume with a centered sphere as the label.""" vol = np.random.rand(*shape).astype(np.float32) * 0.3 label = np.zeros(shape, dtype=np.float32) center = np.array(shape) / 2 coords = np.mgrid[: shape[0], : shape[1], : shape[2]] dist = np.sqrt(sum((c - ctr) ** 2 for c, ctr in zip(coords, center))) mask = dist < radius label[mask] = 1.0 vol[mask] += 0.7 # make sphere brighter return vol, label @pytest.mark.gpu class TestPredictEndToEnd: def test_unet_overfit_dice_above_threshold(self): """Train 2-class UNet to overfit a sphere, then check Dice >= 0.95.""" device = torch.device("cuda") torch.manual_seed(42) vol, label = _make_sphere_volume(shape=(64, 64, 64), radius=20) x = torch.from_numpy(vol[None, None]).to(device) # (1, 1, 64, 64, 64) # One-hot encode label for 2-class: background + foreground label_long = torch.from_numpy(label).long().to(device) # (64, 64, 64) y_onehot = nn.functional.one_hot(label_long, 2) # (64,64,64,2) y_onehot = y_onehot.permute(3, 0, 1, 2).unsqueeze(0).float() # (1,2,64,64,64) model = unet(n_classes=2).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) criterion = nn.CrossEntropyLoss() # Overfit on a single sample model.train() for _ in range(200): optimizer.zero_grad() pred = model(x) # (1, 2, 64, 64, 64) loss = criterion(pred, label_long.unsqueeze(0)) loss.backward() optimizer.step() # Run predict() on the same volume — returns argmax labels model.eval() result = predict( inputs=vol, model=model, block_shape=(32, 32, 32), batch_size=8, device="cuda", return_labels=True, ) pred_arr = np.asarray(result.dataobj) # Compute Dice for class 1 (foreground) pred_bin = (pred_arr == 1).astype(np.float32) intersection = (pred_bin * label).sum() 
dice = 2 * intersection / (pred_bin.sum() + label.sum() + 1e-8) assert dice >= 0.95, f"Dice {dice:.4f} < 0.95 threshold" def test_predict_output_is_nifti_on_gpu(self): """Verify predict() returns a NIfTI image when run on GPU.""" vol, _ = _make_sphere_volume(shape=(32, 32, 32)) model = unet(n_classes=2) result = predict( inputs=vol, model=model, block_shape=(32, 32, 32), batch_size=1, device="cuda", ) assert isinstance(result, nib.Nifti1Image) assert result.shape == (32, 32, 32) ================================================ FILE: nobrainer/tests/integration/__init__.py ================================================ ================================================ FILE: nobrainer/tests/integration/test_datalad_commit.py ================================================ """Integration test for commit_best_model with a real DataLad dataset. Requirements: datalad>=0.19 and git-annex must be installed. No OSF remote is configured — OSF push is skipped gracefully. The 1-hour SC-008 SLA for OSF retrieval requires live OSF and is not validated here (manual verification only). 
""" from __future__ import annotations import json from pathlib import Path import subprocess import pytest import torch datalad = pytest.importorskip("datalad", reason="datalad not installed") @pytest.fixture() def trained_models_dataset(tmp_path): """Create a fresh DataLad dataset in tmp_path/trained_models.""" import datalad.api as dl trained_models = tmp_path / "trained_models" trained_models.mkdir() dl.create(path=str(trained_models)) return trained_models @pytest.fixture() def model_files(tmp_path): """Create dummy model.pth and config.json files.""" run_dir = tmp_path / "run" run_dir.mkdir() model_path = run_dir / "best_model.pth" torch.save({"weights": torch.randn(4, 4)}, str(model_path)) config_path = run_dir / "best_config.json" config_path.write_text(json.dumps({"learning_rate": 1e-4, "batch_size": 4})) return model_path, config_path class TestCommitBestModelIntegration: def test_files_committed_to_datalad(self, trained_models_dataset, model_files): """commit_best_model creates model.pth, config.json, model_card.md in dataset.""" from nobrainer.research.loop import commit_best_model model_path, config_path = model_files result = commit_best_model( best_model_path=model_path, best_config_path=config_path, trained_models_path=trained_models_dataset, model_family="bayesian_vnet", val_dice=0.87, source_run_id="integration_test_001", ) dest = Path(result["path"]) assert (dest / "model.pth").exists() assert (dest / "config.json").exists() assert (dest / "model_card.md").exists() def test_datalad_dataset_is_clean_after_commit( self, trained_models_dataset, model_files ): """datalad status shows no untracked/modified files after commit_best_model.""" import datalad.api as dl from nobrainer.research.loop import commit_best_model model_path, config_path = model_files commit_best_model( best_model_path=model_path, best_config_path=config_path, trained_models_path=trained_models_dataset, model_family="bayesian_vnet", val_dice=0.87, ) status_results = 
list(dl.status(dataset=str(trained_models_dataset))) unclean = [r for r in status_results if r.get("state") not in ("clean", None)] assert len(unclean) == 0, f"Expected clean dataset, got: {unclean}" def test_git_log_contains_commit_message(self, trained_models_dataset, model_files): """Git log in DataLad dataset contains the autoresearch commit.""" from nobrainer.research.loop import commit_best_model model_path, config_path = model_files result = commit_best_model( best_model_path=model_path, best_config_path=config_path, trained_models_path=trained_models_dataset, model_family="bayesian_vnet", val_dice=0.88, ) git_log = subprocess.run( ["git", "log", "--oneline", "-5"], cwd=str(trained_models_dataset), capture_output=True, text=True, check=True, ) assert "bayesian_vnet" in git_log.stdout assert "0.8800" in git_log.stdout assert result["datalad_commit"] in git_log.stdout def test_directory_structure_follows_convention( self, trained_models_dataset, model_files ): """Model files land under neuronets/autoresearch///.""" from nobrainer.research.loop import commit_best_model model_path, config_path = model_files result = commit_best_model( best_model_path=model_path, best_config_path=config_path, trained_models_path=trained_models_dataset, model_family="bayesian_vnet", val_dice=0.90, ) dest = Path(result["path"]) # Path must be: /neuronets/autoresearch/bayesian_vnet/ parts = dest.parts assert "neuronets" in parts assert "autoresearch" in parts assert "bayesian_vnet" in parts def test_model_card_contains_required_metadata( self, trained_models_dataset, model_files ): """model_card.md includes model family, val_dice, source_run_id, and versions.""" from nobrainer.research.loop import commit_best_model model_path, config_path = model_files result = commit_best_model( best_model_path=model_path, best_config_path=config_path, trained_models_path=trained_models_dataset, model_family="bayesian_vnet", val_dice=0.85, source_run_id="run_abc123", ) card = (Path(result["path"]) 
/ "model_card.md").read_text() assert "bayesian_vnet" in card assert "0.8500" in card assert "run_abc123" in card assert "PyTorch" in card def test_osf_push_skipped_gracefully_when_no_remote( self, trained_models_dataset, model_files ): """No OSF remote configured — osf_url is None, function completes normally.""" from nobrainer.research.loop import commit_best_model model_path, config_path = model_files result = commit_best_model( best_model_path=model_path, best_config_path=config_path, trained_models_path=trained_models_dataset, model_family="bayesian_vnet", val_dice=0.80, ) # Without an OSF remote, push fails gracefully and osf_url is None assert result["osf_url"] is None ================================================ FILE: nobrainer/tests/integration/test_research_smoke.py ================================================ """Integration test for autoresearch loop with budget-minutes constraint. T014: Run the full research loop with a 60-second budget and verify it produces a run_summary.md with at least 1 experiment entry. 
""" from __future__ import annotations from nobrainer.research.loop import run_loop class TestResearchSmoke: def test_research_loop_completes_with_budget_seconds(self, tmp_path): """Full research loop with 60s budget, tiny MeshNet on synthetic data.""" # Create a minimal train script that writes val_dice.json quickly train_script = tmp_path / "train.py" train_script.write_text( "import json, random, time\n" "time.sleep(0.5)\n" 'json.dump({"val_dice": round(random.uniform(0.4, 0.9), 4)}, ' 'open("val_dice.json", "w"))\n' ) results = run_loop( working_dir=tmp_path, model_family="meshnet", max_experiments=2, budget_seconds=60, ) # Verify at least 1 experiment ran assert len(results) >= 1, "Expected at least 1 experiment" # Verify run_summary.md exists summary = tmp_path / "run_summary.md" assert summary.exists(), "run_summary.md not created" content = summary.read_text() assert "val_dice" in content.lower() or "experiment" in content.lower() ================================================ FILE: nobrainer/tests/unit/__init__.py ================================================ ================================================ FILE: nobrainer/tests/unit/test_bayesian_layers.py ================================================ """Unit tests for BayesianConv3d, BayesianLinear, and accumulate_kl.""" from __future__ import annotations import pyro import pytest import torch from nobrainer.models.bayesian.layers import BayesianConv3d, BayesianLinear from nobrainer.models.bayesian.utils import accumulate_kl # --------------------------------------------------------------------------- # BayesianConv3d # --------------------------------------------------------------------------- class TestBayesianConv3d: def setup_method(self): pyro.clear_param_store() def _forward(self, layer, x): """Run one forward pass inside a pyro.poutine.trace context.""" with pyro.poutine.trace(): return layer(x) def test_output_shape(self): layer = BayesianConv3d(1, 4, kernel_size=3, padding=1) x = 
torch.zeros(2, 1, 8, 8, 8) out = self._forward(layer, x) assert out.shape == (2, 4, 8, 8, 8) def test_kl_populated_after_forward(self): layer = BayesianConv3d(1, 4, kernel_size=3, padding=1) x = torch.zeros(2, 1, 8, 8, 8) self._forward(layer, x) assert isinstance(layer.kl, torch.Tensor) assert layer.kl.numel() == 1 def test_kl_positive(self): layer = BayesianConv3d(1, 4, kernel_size=3, padding=1) x = torch.zeros(2, 1, 8, 8, 8) self._forward(layer, x) assert layer.kl.item() > 0 def test_kl_varies_across_samples(self): """KL should differ between two forward passes (stochastic weights).""" layer = BayesianConv3d(1, 4, kernel_size=3, padding=1) x = torch.zeros(2, 1, 8, 8, 8) self._forward(layer, x) kl1 = layer.kl.item() self._forward(layer, x) kl2 = layer.kl.item() # They may occasionally be equal, but should usually differ assert kl1 == pytest.approx(kl2, rel=1.0) or kl1 != kl2 def test_prior_laplace(self): layer = BayesianConv3d(1, 4, kernel_size=3, padding=1, prior_type="laplace") x = torch.zeros(2, 1, 8, 8, 8) self._forward(layer, x) assert layer.kl.item() > 0 def test_prior_spike_and_slab(self): layer = BayesianConv3d( 1, 4, kernel_size=3, padding=1, prior_type="spike_and_slab" ) x = torch.zeros(2, 1, 8, 8, 8) out = self._forward(layer, x) assert out.shape == (2, 4, 8, 8, 8) assert isinstance(layer.kl, torch.Tensor) assert torch.isfinite(layer.kl) # Check that z_logit parameter exists assert hasattr(layer, "z_logit") def test_no_bias(self): layer = BayesianConv3d(1, 4, kernel_size=3, padding=1, bias=False) assert layer.bias_mu is None assert layer.bias_rho is None x = torch.zeros(2, 1, 8, 8, 8) self._forward(layer, x) assert layer.kl.item() > 0 def test_weight_sigma_positive(self): layer = BayesianConv3d(1, 4, kernel_size=3) assert (layer.weight_sigma > 0).all() # --------------------------------------------------------------------------- # BayesianLinear # --------------------------------------------------------------------------- class TestBayesianLinear: def 
setup_method(self): pyro.clear_param_store() def _forward(self, layer, x): with pyro.poutine.trace(): return layer(x) def test_output_shape(self): layer = BayesianLinear(16, 8) x = torch.zeros(4, 16) out = self._forward(layer, x) assert out.shape == (4, 8) def test_kl_populated(self): layer = BayesianLinear(16, 8) x = torch.zeros(4, 16) self._forward(layer, x) assert layer.kl.item() > 0 def test_no_bias(self): layer = BayesianLinear(16, 8, bias=False) assert layer.bias_mu is None x = torch.zeros(4, 16) self._forward(layer, x) assert layer.kl.item() > 0 def test_prior_laplace(self): layer = BayesianLinear(16, 8, prior_type="laplace") x = torch.zeros(4, 16) self._forward(layer, x) assert layer.kl.item() > 0 def test_prior_spike_and_slab(self): layer = BayesianLinear(16, 8, prior_type="spike_and_slab") x = torch.zeros(4, 16) out = self._forward(layer, x) assert out.shape == (4, 8) assert torch.isfinite(layer.kl) assert hasattr(layer, "z_logit") # --------------------------------------------------------------------------- # accumulate_kl # --------------------------------------------------------------------------- class TestAccumulateKl: def setup_method(self): pyro.clear_param_store() def test_single_layer(self): layer = BayesianConv3d(1, 4, kernel_size=3, padding=1) x = torch.zeros(2, 1, 8, 8, 8) with pyro.poutine.trace(): layer(x) kl = accumulate_kl(layer) assert kl.item() == pytest.approx(layer.kl.item()) def test_multiple_layers(self): from pyro.nn import PyroModule class _TwoConv(PyroModule): def __init__(self): super().__init__() self.l1 = BayesianConv3d(1, 4, kernel_size=3, padding=1) self.l2 = BayesianConv3d(4, 8, kernel_size=3, padding=1) def forward(self, x): return self.l2(self.l1(x)) model = _TwoConv() x = torch.zeros(2, 1, 8, 8, 8) with pyro.poutine.trace(): model(x) total = accumulate_kl(model) expected = model.l1.kl + model.l2.kl assert total.item() == pytest.approx(expected.item(), rel=1e-5) def test_non_bayesian_model_returns_zero(self): import 
torch.nn as nn model = nn.Sequential(nn.Conv3d(1, 4, 3, padding=1)) kl = accumulate_kl(model) assert kl.item() == 0.0 ================================================ FILE: nobrainer/tests/unit/test_bayesian_models.py ================================================ """Unit tests for BayesianVNet and BayesianMeshNet.""" from __future__ import annotations import pyro import pytest import torch from nobrainer.models.bayesian import ( BayesianMeshNet, BayesianVNet, accumulate_kl, bayesian_meshnet, bayesian_vnet, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _run(model, x): """Forward pass inside a Pyro trace context.""" with pyro.poutine.trace(): return model(x) # --------------------------------------------------------------------------- # BayesianVNet # --------------------------------------------------------------------------- class TestBayesianVNet: def setup_method(self): pyro.clear_param_store() def test_default_construction(self): m = BayesianVNet() assert isinstance(m, BayesianVNet) def test_output_shape_single_class(self): m = BayesianVNet(n_classes=1, in_channels=1, base_filters=8, levels=2) x = torch.zeros(2, 1, 16, 16, 16) out = _run(m, x) assert out.shape == (2, 1, 16, 16, 16) def test_output_shape_multi_class(self): m = BayesianVNet(n_classes=4, in_channels=1, base_filters=8, levels=2) x = torch.zeros(2, 1, 16, 16, 16) out = _run(m, x) assert out.shape == (2, 4, 16, 16, 16) def test_kl_accumulated(self): m = BayesianVNet(n_classes=1, in_channels=1, base_filters=8, levels=2) x = torch.zeros(2, 1, 16, 16, 16) _run(m, x) kl = accumulate_kl(m) assert kl.item() > 0 def test_factory_function(self): m = bayesian_vnet(n_classes=2, in_channels=1, base_filters=8, levels=2) assert isinstance(m, BayesianVNet) def test_laplace_prior(self): m = BayesianVNet( n_classes=1, in_channels=1, base_filters=8, levels=2, prior_type="laplace" ) x = 
torch.zeros(2, 1, 16, 16, 16) _run(m, x) assert accumulate_kl(m).item() > 0 def test_kl_weight_attribute(self): m = BayesianVNet(kl_weight=0.001) assert m.kl_weight == pytest.approx(0.001) # --------------------------------------------------------------------------- # BayesianMeshNet # --------------------------------------------------------------------------- class TestBayesianMeshNet: def setup_method(self): pyro.clear_param_store() def test_default_construction(self): m = BayesianMeshNet() assert isinstance(m, BayesianMeshNet) def test_output_shape_single_class(self): m = BayesianMeshNet(n_classes=1, in_channels=1, filters=8, receptive_field=37) x = torch.zeros(2, 1, 16, 16, 16) out = _run(m, x) assert out.shape == (2, 1, 16, 16, 16) def test_output_shape_multi_class(self): m = BayesianMeshNet(n_classes=4, in_channels=1, filters=8, receptive_field=37) x = torch.zeros(2, 1, 16, 16, 16) out = _run(m, x) assert out.shape == (2, 4, 16, 16, 16) def test_kl_accumulated(self): m = BayesianMeshNet(n_classes=1, in_channels=1, filters=8, receptive_field=37) x = torch.zeros(2, 1, 16, 16, 16) _run(m, x) assert accumulate_kl(m).item() > 0 def test_invalid_receptive_field(self): with pytest.raises(ValueError, match="receptive_field"): BayesianMeshNet(receptive_field=99) def test_all_dilation_schedules(self): for rf in [37, 67, 129]: m = BayesianMeshNet( n_classes=1, in_channels=1, filters=4, receptive_field=rf ) x = torch.zeros(2, 1, 8, 8, 8) out = _run(m, x) assert out.shape == (2, 1, 8, 8, 8) def test_factory_function(self): m = bayesian_meshnet(n_classes=2, in_channels=1, filters=4, receptive_field=37) assert isinstance(m, BayesianMeshNet) def test_kl_weight_attribute(self): m = BayesianMeshNet(kl_weight=1e-4) assert m.kl_weight == pytest.approx(1e-4) ================================================ FILE: nobrainer/tests/unit/test_class_weights.py ================================================ """Unit tests for class weight computation and weighted losses.""" from 
__future__ import annotations import numpy as np import torch from nobrainer.losses import DiceCELoss, compute_class_weights, weighted_cross_entropy class TestComputeClassWeights: def test_uniform_distribution(self, tmp_path): """Equal class counts → all weights ≈ 1.""" import nibabel as nib # Create 2-class volume with equal counts arr = np.zeros((10, 10, 10), dtype=np.int32) arr[:5] = 1 # half zeros, half ones nib.save(nib.Nifti1Image(arr, np.eye(4)), str(tmp_path / "lbl.nii.gz")) w = compute_class_weights([str(tmp_path / "lbl.nii.gz")], n_classes=2) assert w.shape == (2,) assert torch.allclose(w, torch.ones(2), atol=0.01) def test_imbalanced_gives_higher_weight_to_rare(self, tmp_path): """Rare class should get higher weight.""" import nibabel as nib arr = np.zeros((10, 10, 10), dtype=np.int32) arr[0, 0, 0] = 1 # class 1 is very rare nib.save(nib.Nifti1Image(arr, np.eye(4)), str(tmp_path / "lbl.nii.gz")) w = compute_class_weights([str(tmp_path / "lbl.nii.gz")], n_classes=2) assert w[1] > w[0] # rare class gets higher weight def test_median_frequency_method(self, tmp_path): import nibabel as nib arr = np.zeros((10, 10, 10), dtype=np.int32) arr[:2] = 1 arr[:1] = 2 nib.save(nib.Nifti1Image(arr, np.eye(4)), str(tmp_path / "lbl.nii.gz")) w = compute_class_weights( [str(tmp_path / "lbl.nii.gz")], n_classes=3, method="median_frequency", ) assert w.shape == (3,) assert (w > 0).all() def test_max_samples(self, tmp_path): """max_samples limits the number of files scanned.""" import nibabel as nib for i in range(5): arr = np.full((4, 4, 4), i % 2, dtype=np.int32) nib.save( nib.Nifti1Image(arr, np.eye(4)), str(tmp_path / f"lbl_{i}.nii.gz"), ) paths = [str(tmp_path / f"lbl_{i}.nii.gz") for i in range(5)] w = compute_class_weights(paths, n_classes=2, max_samples=2) assert w.shape == (2,) class TestWeightedCrossEntropy: def test_with_weights(self): w = torch.tensor([0.5, 1.5]) loss_fn = weighted_cross_entropy(weight=w) pred = torch.randn(4, 2) target = torch.randint(0, 2, (4,)) 
loss = loss_fn(pred, target) assert loss.ndim == 0 assert torch.isfinite(loss) def test_without_weights(self): loss_fn = weighted_cross_entropy() pred = torch.randn(4, 2) target = torch.randint(0, 2, (4,)) loss = loss_fn(pred, target) assert torch.isfinite(loss) class TestDiceCELoss: def test_3d_segmentation(self): loss_fn = DiceCELoss(softmax=True) pred = torch.randn(2, 3, 8, 8, 8) # 3-class target = torch.randint(0, 3, (2, 8, 8, 8)) loss = loss_fn(pred, target) assert loss.ndim == 0 assert torch.isfinite(loss) def test_with_class_weights(self): w = torch.tensor([0.5, 1.0, 2.0]) loss_fn = DiceCELoss(weight=w, softmax=True) pred = torch.randn(2, 3, 8, 8, 8) target = torch.randint(0, 3, (2, 8, 8, 8)) loss = loss_fn(pred, target) assert torch.isfinite(loss) def test_loss_registry(self): from nobrainer.losses import get loss_cls = get("dice_ce") assert loss_cls is DiceCELoss ================================================ FILE: nobrainer/tests/unit/test_croissant.py ================================================ """Unit tests for nobrainer.processing.croissant helpers (T024).""" from __future__ import annotations import json from pathlib import Path from unittest.mock import MagicMock import nibabel as nib import numpy as np from nobrainer.processing.croissant import ( _sha256, validate_croissant, write_dataset_croissant, write_model_croissant, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_nifti(shape=(16, 16, 16), tmpdir: Path | None = None) -> str: """Write a synthetic NIfTI file and return its path.""" data = np.random.rand(*shape).astype(np.float32) img = nib.Nifti1Image(data, affine=np.eye(4)) path = tmpdir / f"vol_{np.random.randint(0, int(1e6))}.nii.gz" nib.save(img, str(path)) return str(path) def _make_fake_estimator(model_name="unet"): """Create a mock estimator with typical attributes.""" est = MagicMock() est.base_model = 
model_name est.model_args = {"channels": (4, 8), "strides": (2,)} est.n_classes_ = 2 est.block_shape_ = (16, 16, 16) est._optimizer_class = "Adam" est._optimizer_args = {"lr": "0.001"} est._loss_name = "CrossEntropyLoss" return est def _make_fake_dataset(tmp_path, n=2): """Create a mock dataset with real NIfTI files.""" data = [] for _ in range(n): img = _make_nifti((16, 16, 16), tmp_path) lbl = _make_nifti((16, 16, 16), tmp_path) data.append({"image": img, "label": lbl}) ds = MagicMock() ds.data = data ds.volume_shape = (16, 16, 16) ds.n_classes = 2 ds._block_shape = (16, 16, 16) return ds # --------------------------------------------------------------------------- # Tests: write_model_croissant # --------------------------------------------------------------------------- class TestWriteModelCroissant: def test_creates_valid_jsonld(self, tmp_path): """write_model_croissant() creates a valid JSON-LD file.""" est = _make_fake_estimator() ds = _make_fake_dataset(tmp_path) result = { "history": [ {"epoch": 1, "loss": 0.5}, {"epoch": 2, "loss": 0.4}, ], "checkpoint_path": None, } out = write_model_croissant(tmp_path, est, result, ds) assert out.exists() data = json.loads(out.read_text()) assert "@context" in data assert "@type" in data assert data["@type"] == "cr:Dataset" def test_required_provenance_fields(self, tmp_path): """Provenance must contain all required fields.""" est = _make_fake_estimator() ds = _make_fake_dataset(tmp_path) result = { "history": [ {"epoch": 1, "loss": 0.5}, {"epoch": 2, "loss": 0.4}, ], "checkpoint_path": None, } out = write_model_croissant(tmp_path, est, result, ds) data = json.loads(out.read_text()) prov = data["nobrainer:provenance"] assert "source_datasets" in prov assert "training_date" in prov assert "nobrainer_version" in prov assert "model_architecture" in prov def test_provenance_model_architecture(self, tmp_path): est = _make_fake_estimator("meshnet") ds = _make_fake_dataset(tmp_path) out = write_model_croissant(tmp_path, est, 
None, ds) data = json.loads(out.read_text()) assert data["nobrainer:provenance"]["model_architecture"] == "meshnet" def test_sha256_checksums_for_source_datasets(self, tmp_path): """Source datasets must have SHA256 checksums.""" est = _make_fake_estimator() ds = _make_fake_dataset(tmp_path, n=2) out = write_model_croissant(tmp_path, est, None, ds) data = json.loads(out.read_text()) sources = data["nobrainer:provenance"]["source_datasets"] assert len(sources) >= 1 for src in sources: assert "sha256" in src assert len(src["sha256"]) == 64 # SHA256 hex digest length class TestSHA256: def test_checksum_computed(self, tmp_path): """_sha256 returns a 64-char hex digest for a file.""" path = _make_nifti((4, 4, 4), tmp_path) digest = _sha256(path) assert isinstance(digest, str) assert len(digest) == 64 def test_deterministic(self, tmp_path): """Same file produces same checksum.""" path = _make_nifti((4, 4, 4), tmp_path) assert _sha256(path) == _sha256(path) # --------------------------------------------------------------------------- # Tests: validate_croissant # --------------------------------------------------------------------------- class TestValidateCroissant: def test_returns_true_on_valid(self, tmp_path): """validate_croissant() returns True on a valid file.""" est = _make_fake_estimator() ds = _make_fake_dataset(tmp_path) out = write_model_croissant(tmp_path, est, None, ds) assert validate_croissant(out) is True # --------------------------------------------------------------------------- # Tests: write_dataset_croissant # --------------------------------------------------------------------------- class TestWriteDatasetCroissant: def test_writes_dataset_metadata(self, tmp_path): """write_dataset_croissant() writes a valid JSON-LD.""" ds = _make_fake_dataset(tmp_path) out = write_dataset_croissant(tmp_path / "ds_croissant.json", ds) assert out.exists() data = json.loads(out.read_text()) assert "@context" in data assert "@type" in data assert data["@type"] == 
"cr:Dataset" def test_dataset_info_present(self, tmp_path): ds = _make_fake_dataset(tmp_path) out = write_dataset_croissant(tmp_path / "ds_croissant.json", ds) data = json.loads(out.read_text()) assert "nobrainer:dataset_info" in data info = data["nobrainer:dataset_info"] assert info["n_classes"] == 2 assert info["n_volumes"] == 2 def test_distribution_has_sha256(self, tmp_path): ds = _make_fake_dataset(tmp_path) out = write_dataset_croissant(tmp_path / "ds_croissant.json", ds) data = json.loads(out.read_text()) for item in data["distribution"]: assert "sha256" in item assert len(item["sha256"]) == 64 ================================================ FILE: nobrainer/tests/unit/test_dataset.py ================================================ """Unit tests for nobrainer.dataset.get_dataset().""" from pathlib import Path import tempfile import nibabel as nib import numpy as np import pytest from nobrainer.dataset import get_dataset # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_nifti(shape=(16, 16, 16), tmpdir: Path | None = None) -> str: """Write a synthetic NIfTI file and return its path.""" if tmpdir is None: tmpdir = Path(tempfile.mkdtemp()) data = np.random.rand(*shape).astype(np.float32) img = nib.Nifti1Image(data, affine=np.eye(4)) path = tmpdir / f"vol_{np.random.randint(0, 1e6)}.nii.gz" nib.save(img, str(path)) return str(path) # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- class TestGetDataset: def test_batch_shape_image_only(self, tmp_path): """Verify batch shape (B, 1, D, H, W) for image-only dataset.""" paths = [_make_nifti((16, 16, 16), tmp_path) for _ in range(3)] loader = get_dataset( image_paths=paths, batch_size=2, num_workers=0, cache_rate=0.0, ) batch = next(iter(loader)) assert "image" in batch assert 
batch["image"].ndim == 5 # (B, C, D, H, W) assert batch["image"].shape[0] == 2 # batch size assert batch["image"].shape[1] == 1 # channel def test_batch_shape_with_labels(self, tmp_path): """Verify both image and label tensors are returned.""" image_paths = [_make_nifti((16, 16, 16), tmp_path) for _ in range(2)] label_paths = [_make_nifti((16, 16, 16), tmp_path) for _ in range(2)] loader = get_dataset( image_paths=image_paths, label_paths=label_paths, batch_size=2, num_workers=0, cache_rate=0.0, ) batch = next(iter(loader)) assert "image" in batch assert "label" in batch def test_mismatch_raises(self, tmp_path): """Mismatched image/label list lengths should raise ValueError.""" paths = [_make_nifti((16, 16, 16), tmp_path) for _ in range(2)] with pytest.raises(ValueError, match="len"): get_dataset( image_paths=paths, label_paths=paths[:1], batch_size=1, num_workers=0, ) def test_augment_flag(self, tmp_path): """augment=True should not crash the dataloader.""" paths = [_make_nifti((16, 16, 16), tmp_path) for _ in range(2)] loader = get_dataset( image_paths=paths, batch_size=2, num_workers=0, augment=True, cache_rate=0.0, ) batch = next(iter(loader)) assert batch["image"].shape[1] == 1 def test_returns_dataloader(self, tmp_path): paths = [_make_nifti((16, 16, 16), tmp_path)] loader = get_dataset( image_paths=paths, batch_size=1, num_workers=0, cache_rate=0.0 ) from torch.utils.data import DataLoader assert isinstance(loader, DataLoader) ================================================ FILE: nobrainer/tests/unit/test_dataset_builder.py ================================================ """Unit tests for nobrainer.processing.dataset.Dataset fluent builder (T013).""" from __future__ import annotations import json from pathlib import Path import nibabel as nib import numpy as np from torch.utils.data import DataLoader from nobrainer.processing.dataset import Dataset # --------------------------------------------------------------------------- # Helpers # 
--------------------------------------------------------------------------- def _make_nifti(shape=(16, 16, 16), tmpdir: Path | None = None) -> str: """Write a synthetic NIfTI file and return its path.""" data = np.random.rand(*shape).astype(np.float32) img = nib.Nifti1Image(data, affine=np.eye(4)) path = tmpdir / f"vol_{np.random.randint(0, int(1e6))}.nii.gz" nib.save(img, str(path)) return str(path) def _make_file_pairs(n, shape, tmpdir): """Create n (image, label) NIfTI file pairs.""" pairs = [] for _ in range(n): img_path = _make_nifti(shape, tmpdir) lbl_path = _make_nifti(shape, tmpdir) pairs.append((img_path, lbl_path)) return pairs # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- class TestFromFiles: def test_tuple_format(self, tmp_path): """from_files() accepts list of (image, label) tuples.""" pairs = _make_file_pairs(3, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16), n_classes=2) assert len(ds.data) == 3 assert all("image" in d and "label" in d for d in ds.data) def test_dict_format(self, tmp_path): """from_files() accepts list of dicts.""" pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) dicts = [{"image": img, "label": lbl} for img, lbl in pairs] ds = Dataset.from_files(dicts, block_shape=(16, 16, 16), n_classes=2) assert len(ds.data) == 2 def test_volume_shape_detected(self, tmp_path): """from_files() detects volume_shape from the first NIfTI.""" pairs = _make_file_pairs(1, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16)) assert ds.volume_shape == (16, 16, 16) class TestFluentChaining: def test_batch_returns_self(self, tmp_path): pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16)) result = ds.batch(4) assert result is ds def test_shuffle_returns_self(self, tmp_path): pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) 
ds = Dataset.from_files(pairs, block_shape=(16, 16, 16)) result = ds.shuffle() assert result is ds def test_augment_returns_self(self, tmp_path): pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16)) result = ds.augment() assert result is ds def test_chaining(self, tmp_path): """Chaining .batch().shuffle().augment() returns the same instance.""" pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16)) result = ds.batch(2).shuffle().augment() assert result is ds assert ds._batch_size == 2 assert ds._shuffle is True assert ds._augment is True class TestSplit: def test_split_sizes(self, tmp_path): """split() returns two Datasets with correct combined size.""" pairs = _make_file_pairs(10, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16), n_classes=2) train, val = ds.split(eval_size=0.2) assert len(train.data) + len(val.data) == 10 assert len(val.data) == 2 # int(10 * 0.2) = 2 def test_split_returns_datasets(self, tmp_path): pairs = _make_file_pairs(4, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16)) train, val = ds.split(eval_size=0.25) assert isinstance(train, Dataset) assert isinstance(val, Dataset) class TestDataloader: def test_returns_dataloader(self, tmp_path): """dataloader property returns a torch DataLoader.""" pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16), n_classes=2).batch(2) loader = ds.dataloader assert isinstance(loader, DataLoader) def test_batch_produces_data(self, tmp_path): """DataLoader yields batches with image data.""" pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16), n_classes=2).batch(2) batch = next(iter(ds.dataloader)) # MONAI DataLoader returns dict with "image" key assert "image" in batch assert batch["image"].ndim == 5 # (B, C, D, H, W) class 
TestMetadataProperties: def test_batch_size(self, tmp_path): pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16)).batch(4) assert ds.batch_size == 4 def test_block_shape(self, tmp_path): pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16)) assert ds.block_shape == (16, 16, 16) def test_volume_shape(self, tmp_path): pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16)) assert ds.volume_shape == (16, 16, 16) def test_n_classes(self, tmp_path): pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16), n_classes=3) assert ds.n_classes == 3 class TestToCroissant: def test_writes_valid_jsonld(self, tmp_path): """to_croissant() writes valid JSON-LD with @context and fields.""" pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16), n_classes=2) out = ds.to_croissant(tmp_path / "dataset_croissant.json") assert out.exists() data = json.loads(out.read_text()) assert "@context" in data assert "@type" in data assert data["@type"] == "cr:Dataset" def test_has_dataset_info(self, tmp_path): pairs = _make_file_pairs(2, (16, 16, 16), tmp_path) ds = Dataset.from_files(pairs, block_shape=(16, 16, 16), n_classes=2) out = ds.to_croissant(tmp_path / "dataset_croissant.json") data = json.loads(out.read_text()) assert "nobrainer:dataset_info" in data info = data["nobrainer:dataset_info"] assert info["n_classes"] == 2 assert info["n_volumes"] == 2 ================================================ FILE: nobrainer/tests/unit/test_datasets_openneuro.py ================================================ """Unit tests for nobrainer.datasets.openneuro.""" from __future__ import annotations from pathlib import Path from unittest.mock import patch import pytest class TestWriteManifest: """Test write_manifest without DataLad.""" def 
test_creates_csv(self, tmp_path): from nobrainer.datasets.openneuro import write_manifest pairs = [ { "subject_id": f"sub-{i:02d}", "t1w_path": f"/t1_{i}.nii.gz", "label_path": f"/lbl_{i}.nii.gz", } for i in range(5) ] csv_path = write_manifest(pairs, tmp_path / "manifest.csv") assert csv_path.exists() import csv with open(csv_path) as f: rows = list(csv.DictReader(f)) assert len(rows) == 5 splits = {r["split"] for r in rows} assert splits <= {"train", "val", "test"} def test_split_ratios(self, tmp_path): from nobrainer.datasets.openneuro import write_manifest pairs = [ { "subject_id": f"sub-{i:02d}", "t1w_path": f"/t1_{i}.nii.gz", "label_path": f"/lbl_{i}.nii.gz", } for i in range(10) ] write_manifest(pairs, tmp_path / "m.csv", split_ratios=(60, 20, 20)) import csv with open(tmp_path / "m.csv") as f: rows = list(csv.DictReader(f)) n_train = sum(1 for r in rows if r["split"] == "train") assert n_train == 6 # 60% of 10 def test_dataset_id_column(self, tmp_path): from nobrainer.datasets.openneuro import write_manifest pairs = [ { "subject_id": "sub-01", "dataset_id": "ds000114", "t1w_path": "/t1.nii.gz", "label_path": "/lbl.nii.gz", } ] csv_path = write_manifest(pairs, tmp_path / "m.csv") import csv with open(csv_path) as f: reader = csv.DictReader(f) row = next(reader) assert row["dataset_id"] == "ds000114" class TestGlobDataset: """Test glob_dataset (no DataLad needed).""" def test_finds_files(self, tmp_path): from nobrainer.datasets.openneuro import glob_dataset (tmp_path / "sub-01" / "anat").mkdir(parents=True) (tmp_path / "sub-01" / "anat" / "sub-01_T1w.nii.gz").touch() (tmp_path / "sub-02" / "anat").mkdir(parents=True) (tmp_path / "sub-02" / "anat" / "sub-02_T1w.nii.gz").touch() files = glob_dataset(tmp_path, "sub-*/anat/*_T1w.nii.gz") assert len(files) == 2 def test_no_matches(self, tmp_path): from nobrainer.datasets.openneuro import glob_dataset files = glob_dataset(tmp_path, "sub-*/anat/*_T1w.nii.gz") assert files == [] class TestExtractSubjectId: def 
test_from_bids_path(self, tmp_path): from nobrainer.datasets.openneuro import _extract_subject_id p = tmp_path / "sub-03" / "anat" / "sub-03_T1w.nii.gz" assert _extract_subject_id(p) == "sub-03" def test_from_filename(self): from nobrainer.datasets.openneuro import _extract_subject_id p = Path("sub-99_desc-preproc_T1w.nii.gz") assert _extract_subject_id(p) == "sub-99" class TestFileOk: def test_real_file(self, tmp_path): from nobrainer.datasets.openneuro import _file_ok f = tmp_path / "real.nii.gz" f.write_bytes(b"data") assert _file_ok(f) def test_empty_file(self, tmp_path): from nobrainer.datasets.openneuro import _file_ok f = tmp_path / "empty.nii.gz" f.touch() assert not _file_ok(f) def test_missing_file(self, tmp_path): from nobrainer.datasets.openneuro import _file_ok assert not _file_ok(tmp_path / "missing.nii.gz") class TestImportGuard: """Test that missing datalad gives a clear error.""" def test_install_without_datalad(self): from nobrainer.datasets.openneuro import install_dataset with patch.dict("sys.modules", {"datalad": None, "datalad.api": None}): with pytest.raises(ImportError, match="DataLad"): install_dataset("ds000114", "/tmp/test") ================================================ FILE: nobrainer/tests/unit/test_estimator_generation.py ================================================ """Unit tests for nobrainer.processing.generation.Generation estimator (T029).""" from __future__ import annotations import json import nibabel as nib import torch from torch.utils.data import DataLoader, TensorDataset from nobrainer.processing.generation import Generation # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- SPATIAL = 4 GAN_ARGS = { "latent_size": 8, "fmap_base": 16, "fmap_max": 16, "resolution_schedule": [4], "steps_per_phase": 100, } class _FakeDataset: """Minimal dataset-like object for Generation.fit().""" def __init__(self, loader): 
self._loader = loader self.data = [] @property def dataloader(self): return self._loader def _make_fake_dataset(n=4, spatial=SPATIAL, batch_size=2): """Build a fake dataset with tiny synthetic volumes.""" imgs = torch.randn(n, 1, spatial, spatial, spatial) loader = DataLoader(TensorDataset(imgs), batch_size=batch_size) return _FakeDataset(loader) # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- class TestGenerationFit: def test_fit_returns_self(self): """Generation('progressivegan').fit() returns self.""" ds = _make_fake_dataset() gen = Generation("progressivegan", model_args=GAN_ARGS, multi_gpu=False) result = gen.fit( ds, epochs=10, accelerator="cpu", enable_progress_bar=False, ) assert result is gen def test_model_created_after_fit(self): ds = _make_fake_dataset() gen = Generation("progressivegan", model_args=GAN_ARGS, multi_gpu=False) gen.fit( ds, epochs=5, accelerator="cpu", enable_progress_bar=False, ) assert gen.model_ is not None class TestGenerationGenerate: def test_generate_returns_list_of_nifti(self): """.generate(2) returns list of 2 nibabel.Nifti1Image.""" ds = _make_fake_dataset() gen = Generation("progressivegan", model_args=GAN_ARGS, multi_gpu=False) gen.fit( ds, epochs=5, accelerator="cpu", enable_progress_bar=False, ) images = gen.generate(2) assert isinstance(images, list) assert len(images) == 2 for img in images: assert isinstance(img, nib.Nifti1Image) class TestGenerationSave: def test_save_creates_croissant(self, tmp_path): """.save() creates croissant.json.""" ds = _make_fake_dataset() gen = Generation("progressivegan", model_args=GAN_ARGS, multi_gpu=False) gen.fit( ds, epochs=5, accelerator="cpu", enable_progress_bar=False, ) save_dir = tmp_path / "gen_out" gen.save(save_dir) assert (save_dir / "model.pth").exists() assert (save_dir / "croissant.json").exists() data = json.loads((save_dir / "croissant.json").read_text()) 
assert "@context" in data assert "nobrainer:provenance" in data ================================================ FILE: nobrainer/tests/unit/test_estimator_segmentation.py ================================================ """Unit tests for nobrainer.processing.segmentation.Segmentation estimator (T023).""" from __future__ import annotations import json from pathlib import Path import nibabel as nib import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader, TensorDataset from nobrainer.processing.segmentation import Segmentation # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- SPATIAL = 16 N_CLASSES = 2 def _make_nifti(shape=(16, 16, 16), tmpdir: Path | None = None) -> str: """Write a synthetic NIfTI file and return its path.""" data = np.random.rand(*shape).astype(np.float32) img = nib.Nifti1Image(data, affine=np.eye(4)) path = tmpdir / f"vol_{np.random.randint(0, int(1e6))}.nii.gz" nib.save(img, str(path)) return str(path) def _make_tiny_loader(n=4, spatial=SPATIAL, n_classes=N_CLASSES, batch_size=2): """Create a tiny DataLoader with tuple batches for training.""" x = torch.randn(n, 1, spatial, spatial, spatial) y = torch.randint(0, n_classes, (n, spatial, spatial, spatial)) ds = TensorDataset(x, y) return DataLoader(ds, batch_size=batch_size) class _FakeDataset: """Minimal object mimicking the Dataset builder for Segmentation.fit().""" def __init__(self, loader, block_shape, volume_shape, n_classes): self._loader = loader self._block_shape = block_shape self.volume_shape = volume_shape self.n_classes = n_classes @property def block_shape(self): return self._block_shape @property def dataloader(self): return self._loader def _make_fake_dataset(n=4, spatial=SPATIAL, n_classes=N_CLASSES, batch_size=2): """Build a FakeDataset with a tiny DataLoader.""" loader = _make_tiny_loader(n, spatial, n_classes, batch_size) 
return _FakeDataset( loader, block_shape=(spatial, spatial, spatial), volume_shape=(spatial, spatial, spatial), n_classes=n_classes, ) # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- class TestSegmentationFit: def test_fit_returns_self(self): """Segmentation('unet').fit(ds, epochs=2) returns self.""" ds = _make_fake_dataset() seg = Segmentation( "unet", model_args={"channels": (4, 8), "strides": (2,)}, multi_gpu=False, ) result = seg.fit(ds, epochs=2) assert result is seg def test_model_created_after_fit(self): ds = _make_fake_dataset() seg = Segmentation( "unet", model_args={"channels": (4, 8), "strides": (2,)}, multi_gpu=False, ) seg.fit(ds, epochs=1) assert seg.model_ is not None assert isinstance(seg.model_, nn.Module) class TestSegmentationPredict: def test_predict_returns_nifti(self, tmp_path): """.predict() returns nibabel.Nifti1Image with correct shape.""" ds = _make_fake_dataset() seg = Segmentation( "unet", model_args={"channels": (4, 8), "strides": (2,)}, multi_gpu=False, ) seg.fit(ds, epochs=1) # Create a test volume vol_path = _make_nifti((SPATIAL, SPATIAL, SPATIAL), tmp_path) result = seg.predict(vol_path, block_shape=(SPATIAL, SPATIAL, SPATIAL)) assert isinstance(result, nib.Nifti1Image) assert result.shape[:3] == (SPATIAL, SPATIAL, SPATIAL) class TestSegmentationSaveLoad: def test_save_creates_files(self, tmp_path): """.save() creates model.pth and croissant.json.""" ds = _make_fake_dataset() seg = Segmentation( "unet", model_args={"channels": (4, 8), "strides": (2,)}, multi_gpu=False, ) seg.fit(ds, epochs=1) save_dir = tmp_path / "model_out" seg.save(save_dir) assert (save_dir / "model.pth").exists() assert (save_dir / "croissant.json").exists() def test_croissant_provenance_fields(self, tmp_path): """croissant.json contains all provenance fields.""" ds = _make_fake_dataset() seg = Segmentation( "unet", model_args={"channels": (4, 8), 
"strides": (2,)}, multi_gpu=False, ) seg.fit(ds, epochs=1) save_dir = tmp_path / "model_out" seg.save(save_dir) data = json.loads((save_dir / "croissant.json").read_text()) prov = data["nobrainer:provenance"] assert "source_datasets" in prov assert "training_date" in prov assert "nobrainer_version" in prov assert "model_architecture" in prov assert prov["model_architecture"] == "unet" def test_load_roundtrip(self, tmp_path): """.load() round-trip produces same prediction output.""" ds = _make_fake_dataset() seg = Segmentation( "unet", model_args={"channels": (4, 8), "strides": (2,)}, multi_gpu=False, ) seg.fit(ds, epochs=1) # Get prediction before save test_vol = np.random.rand(SPATIAL, SPATIAL, SPATIAL).astype(np.float32) pred_before = seg.predict(test_vol, block_shape=(SPATIAL, SPATIAL, SPATIAL)) # Save and reload save_dir = tmp_path / "model_out" seg.save(save_dir) loaded = Segmentation.load(save_dir, multi_gpu=False) # Predict again pred_after = loaded.predict(test_vol, block_shape=(SPATIAL, SPATIAL, SPATIAL)) np.testing.assert_array_equal( np.asarray(pred_before.dataobj), np.asarray(pred_after.dataobj), ) ================================================ FILE: nobrainer/tests/unit/test_experiment.py ================================================ """Unit tests for nobrainer.experiment tracking.""" from __future__ import annotations import json from nobrainer.experiment import ExperimentTracker class TestExperimentTracker: def test_local_logging(self, tmp_path): tracker = ExperimentTracker( output_dir=tmp_path, config={"lr": 0.001}, use_wandb=False ) tracker.log({"epoch": 1, "loss": 0.5}) tracker.log({"epoch": 2, "loss": 0.3}) tracker.finish() # Check JSONL lines = (tmp_path / "metrics.jsonl").read_text().strip().split("\n") assert len(lines) == 2 assert json.loads(lines[0])["loss"] == 0.5 # Check CSV csv_lines = (tmp_path / "metrics.csv").read_text().strip().split("\n") assert len(csv_lines) == 3 # header + 2 rows assert "epoch" in csv_lines[0] # Check config 
config = json.loads((tmp_path / "config.json").read_text()) assert config["lr"] == 0.001 def test_callback(self, tmp_path): tracker = ExperimentTracker(output_dir=tmp_path, use_wandb=False) cb = tracker.callback(variant="test") # Simulate training callback cb(0, {"loss": 1.5}, None) # (epoch, logs_dict, model) cb(1, {"loss": 0.8}, None) tracker.finish() lines = (tmp_path / "metrics.jsonl").read_text().strip().split("\n") assert len(lines) == 2 row = json.loads(lines[0]) assert row["epoch"] == 0 assert row["loss"] == 1.5 assert row["variant"] == "test" def test_no_wandb_by_default(self, tmp_path): tracker = ExperimentTracker(output_dir=tmp_path) # Should not fail even without wandb installed tracker.log({"x": 1}) tracker.finish() ================================================ FILE: nobrainer/tests/unit/test_generative.py ================================================ """Unit tests for ProgressiveGAN and DCGAN (CPU smoke tests).""" from __future__ import annotations import pytorch_lightning as pl import torch from torch.utils.data import DataLoader, TensorDataset from nobrainer.models.generative import DCGAN, ProgressiveGAN, dcgan, progressivegan # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _tiny_loader(batch_size: int = 2, spatial: int = 4) -> DataLoader: """Return a DataLoader with synthetic 3-D volumes.""" imgs = torch.randn(4, 1, spatial, spatial, spatial) return DataLoader(TensorDataset(imgs), batch_size=batch_size) # --------------------------------------------------------------------------- # ProgressiveGAN # --------------------------------------------------------------------------- class TestProgressiveGAN: def test_construction(self): m = ProgressiveGAN( latent_size=8, fmap_base=16, fmap_max=16, resolution_schedule=[4, 8] ) assert isinstance(m, ProgressiveGAN) def test_factory_function(self): m = progressivegan( latent_size=8, 
fmap_base=16, fmap_max=16, resolution_schedule=[4, 8] ) assert isinstance(m, ProgressiveGAN) def test_generator_output_shape(self): m = ProgressiveGAN( latent_size=8, fmap_base=16, fmap_max=16, resolution_schedule=[4] ) m.generator.current_level = 0 m.generator.alpha = 1.0 z = torch.randn(2, 8) out = m.generator(z) assert out.shape[0] == 2 assert out.shape[1] == 1 def test_discriminator_output_shape(self): m = ProgressiveGAN( latent_size=8, fmap_base=16, fmap_max=16, resolution_schedule=[4] ) m.discriminator.current_level = 0 img = torch.randn(2, 1, 4, 4, 4) out = m.discriminator(img) assert out.shape == (2, 1) def test_training_step_losses_finite(self): """5-step CPU training smoke test.""" m = ProgressiveGAN( latent_size=8, fmap_base=16, fmap_max=16, resolution_schedule=[4], steps_per_phase=10, ) loader = _tiny_loader(batch_size=2, spatial=4) trainer = pl.Trainer( max_steps=5, accelerator="cpu", enable_checkpointing=False, logger=False, enable_progress_bar=False, ) trainer.fit(m, loader) # Verify that logged losses are finite assert m._step_count > 0 def test_alpha_schedule(self): m = ProgressiveGAN( latent_size=8, fmap_base=16, fmap_max=16, resolution_schedule=[4, 8], steps_per_phase=10, ) m._step_count = 5 m.on_train_batch_end() assert 0.0 <= m.generator.alpha <= 1.0 # --------------------------------------------------------------------------- # DCGAN # --------------------------------------------------------------------------- class TestDCGAN: def test_construction(self): m = DCGAN(latent_size=8, n_filters=4) assert isinstance(m, DCGAN) def test_factory_function(self): m = dcgan(latent_size=8, n_filters=4) assert isinstance(m, DCGAN) def test_generator_output_shape(self): m = DCGAN(latent_size=8, n_filters=4) z = torch.randn(2, 8) out = m.generator(z) assert out.shape[0] == 2 assert out.shape[1] == 1 def test_discriminator_output_shape(self): m = DCGAN(latent_size=8, n_filters=4) img = torch.randn(2, 1, 64, 64, 64) out = m.discriminator(img) assert out.shape 
== (2, 1) def test_training_step_losses_finite(self): """5-step CPU training smoke test.""" m = DCGAN(latent_size=8, n_filters=4) loader = _tiny_loader(batch_size=2, spatial=4) trainer = pl.Trainer( max_steps=5, accelerator="cpu", enable_checkpointing=False, logger=False, enable_progress_bar=False, ) trainer.fit(m, loader) # No assertion needed — if fit() completes without error, losses were finite def test_configure_optimizers(self): m = DCGAN(latent_size=8, n_filters=4) opts = m.configure_optimizers() assert len(opts) == 2 # (opt_g, opt_d) ================================================ FILE: nobrainer/tests/unit/test_gpu.py ================================================ """Unit tests for nobrainer.gpu utilities.""" from __future__ import annotations import torch from nobrainer.gpu import get_device, gpu_count, gpu_info, scale_for_multi_gpu class TestGetDevice: def test_returns_torch_device(self): d = get_device() assert isinstance(d, torch.device) def test_device_type_known(self): d = get_device() assert d.type in ("cuda", "mps", "cpu") class TestGpuCount: def test_returns_int(self): n = gpu_count() assert isinstance(n, int) assert n >= 0 class TestGpuInfo: def test_returns_list(self): info = gpu_info() assert isinstance(info, list) if torch.cuda.is_available(): assert len(info) > 0 assert "name" in info[0] assert "memory_gb" in info[0] class TestScaleForMultiGpu: def test_no_gpu_returns_base(self): if torch.cuda.is_available(): return # skip on GPU machines eff, per, n = scale_for_multi_gpu(base_batch_size=32) assert eff == 32 assert per == 32 assert n == 0 def test_simple_division(self): # Without model, just divides eff, per, n = scale_for_multi_gpu(base_batch_size=32) if n > 0: assert eff == per * n ================================================ FILE: nobrainer/tests/unit/test_io_weights.py ================================================ """Unit tests for convert_weights() in nobrainer.io.""" from pathlib import Path import h5py import numpy as np 
import torch import torch.nn as nn from nobrainer.io import convert_weights # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- class _SimplePT(nn.Module): """Minimal PyTorch model for weight-conversion tests.""" def __init__(self): super().__init__() self.conv = nn.Conv3d(1, 4, 3, padding=1, bias=True) self.bn = nn.BatchNorm3d(4) def forward(self, x): return self.bn(self.conv(x)) def _write_synthetic_h5(path: str, model: nn.Module) -> None: """Write a synthetic H5 file that mimics Keras weight layout.""" with h5py.File(path, "w") as hf: sd = model.state_dict() for k, v in sd.items(): w = v.numpy() # Transpose conv weights back to Keras format for the test if w.ndim == 5: w = np.transpose(w, (2, 3, 4, 1, 0)) # Cout,Cin,D,H,W → D,H,W,Cin,Cout hf.create_dataset( k.replace(".", "/") + "/kernel" if w.ndim == 5 else k, data=w ) # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- class TestConvertWeights: def test_returns_dict(self, tmp_path): model = _SimplePT() h5_path = str(tmp_path / "weights.h5") # Write a minimal H5 that has some datasets with h5py.File(h5_path, "w") as hf: hf.create_dataset("dummy", data=np.zeros(4)) result = convert_weights(h5_path, model) assert isinstance(result, dict) def test_output_pth_written(self, tmp_path): model = _SimplePT() h5_path = str(tmp_path / "weights.h5") pth_path = str(tmp_path / "weights.pth") with h5py.File(h5_path, "w") as hf: hf.create_dataset("dummy", data=np.zeros(4)) convert_weights(h5_path, model, output_path=pth_path) assert Path(pth_path).exists() loaded = torch.load(pth_path, map_location="cpu", weights_only=True) assert isinstance(loaded, dict) def test_state_dict_keys_preserved(self, tmp_path): """Model state dict should have same keys before and after conversion.""" model = _SimplePT() 
original_keys = set(model.state_dict().keys()) h5_path = str(tmp_path / "weights.h5") with h5py.File(h5_path, "w") as hf: hf.create_dataset("dummy", data=np.zeros(4)) convert_weights(h5_path, model) assert set(model.state_dict().keys()) == original_keys ================================================ FILE: nobrainer/tests/unit/test_io_zarr.py ================================================ """Unit tests for NIfTI <-> Zarr v3 conversion.""" from __future__ import annotations import nibabel as nib import numpy as np import pytest zarr = pytest.importorskip("zarr", reason="zarr not installed") from nobrainer.io import nifti_to_zarr, zarr_to_nifti # noqa: E402 def _make_nifti(tmp_path, shape=(32, 32, 32)): """Create a synthetic NIfTI file and return path + data.""" data = np.random.rand(*shape).astype(np.float32) affine = np.diag([2.0, 2.0, 2.0, 1.0]) img = nib.Nifti1Image(data, affine) path = str(tmp_path / "test.nii.gz") nib.save(img, path) return path, data, affine class TestNiftiToZarr: def test_creates_valid_store(self, tmp_path): nii_path, data, _ = _make_nifti(tmp_path) zarr_path = nifti_to_zarr(nii_path, tmp_path / "out.zarr") store = zarr.open_group(str(zarr_path), mode="r") arr = np.asarray(store["0"]) assert arr.shape == data.shape assert arr.dtype == np.float32 def test_provenance_stored(self, tmp_path): nii_path, _, _ = _make_nifti(tmp_path) zarr_path = nifti_to_zarr(nii_path, tmp_path / "out.zarr") store = zarr.open_group(str(zarr_path), mode="r") prov = store.attrs.get("nobrainer_provenance") assert prov is not None assert "source_file" in prov assert "created_at" in prov assert "nobrainer_version" in prov assert prov["tool"] == "nobrainer.io.nifti_to_zarr" def test_multi_resolution_pyramid(self, tmp_path): nii_path, data, _ = _make_nifti(tmp_path, shape=(64, 64, 64)) zarr_path = nifti_to_zarr(nii_path, tmp_path / "pyramid.zarr", levels=3) store = zarr.open_group(str(zarr_path), mode="r") # Level 0: full resolution assert np.asarray(store["0"]).shape 
== (64, 64, 64) # Downsampled levels should have smaller shapes level1 = np.asarray(store["1"]) assert all(s <= 64 for s in level1.shape) level2 = np.asarray(store["2"]) assert all(s <= level1.shape[i] for i, s in enumerate(level2.shape)) class TestZarrToNifti: def test_round_trip_shape(self, tmp_path): """NIfTI -> Zarr -> NIfTI preserves shape.""" nii_path, data, _ = _make_nifti(tmp_path) zarr_path = nifti_to_zarr(nii_path, tmp_path / "rt.zarr") rt_path = zarr_to_nifti(zarr_path, tmp_path / "roundtrip.nii.gz") rt_img = nib.load(str(rt_path)) assert rt_img.shape == data.shape def test_round_trip_data(self, tmp_path): """NIfTI -> Zarr -> NIfTI preserves data values.""" nii_path, data, _ = _make_nifti(tmp_path) zarr_path = nifti_to_zarr(nii_path, tmp_path / "rt.zarr") rt_path = zarr_to_nifti(zarr_path, tmp_path / "roundtrip.nii.gz") rt_img = nib.load(str(rt_path)) rt_data = np.asarray(rt_img.dataobj, dtype=np.float32) # Value range should be preserved assert abs(rt_data.mean() - data.mean()) < 0.1 assert rt_data.min() >= 0 assert rt_data.max() <= 1.0 + 0.01 def test_round_trip_level1(self, tmp_path): """Exporting level 1 gives a smaller shape.""" nii_path, _, _ = _make_nifti(tmp_path, shape=(64, 64, 64)) zarr_path = nifti_to_zarr(nii_path, tmp_path / "pyr.zarr", levels=2) rt_path = zarr_to_nifti(zarr_path, tmp_path / "level1.nii.gz", level=1) rt_img = nib.load(str(rt_path)) # Level 1 should be smaller than full resolution assert all(s <= 64 for s in rt_img.shape) ================================================ FILE: nobrainer/tests/unit/test_layers.py ================================================ """Unit tests for nobrainer.layers (PyTorch implementations).""" import pytest import torch from nobrainer.layers import ( BernoulliDropout, ConcreteDropout, GaussianDropout, MaxPool4D, ) # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- SHAPE_3D = (1, 1, 
8, 8, 8) SHAPE_4D = (1, 1, 2, 8, 8, 8) # (N, C, V, D, H, W) @pytest.fixture def x3d(): return torch.ones(SHAPE_3D) @pytest.fixture def x4d(): return torch.ones(SHAPE_4D, requires_grad=False) # --------------------------------------------------------------------------- # BernoulliDropout # --------------------------------------------------------------------------- class TestBernoulliDropout: def test_forward_shape(self, x3d): layer = BernoulliDropout(rate=0.3, is_monte_carlo=True) layer.train() out = layer(x3d) assert out.shape == x3d.shape def test_passthrough_eval_scale(self, x3d): """With scale_during_training=True, eval mode returns x unchanged.""" layer = BernoulliDropout( rate=0.5, is_monte_carlo=False, scale_during_training=True ) layer.eval() out = layer(x3d) assert torch.allclose(out, x3d) def test_passthrough_eval_noscale(self, x3d): """With scale_during_training=False, eval mode returns x * keep_prob.""" rate = 0.3 layer = BernoulliDropout( rate=rate, is_monte_carlo=False, scale_during_training=False ) layer.eval() out = layer(x3d) assert torch.allclose(out, x3d * (1.0 - rate)) def test_gradient_flow(self, x3d): x = x3d.clone().requires_grad_(True) layer = BernoulliDropout(rate=0.3, is_monte_carlo=True, seed=42) layer.train() out = layer(x) loss = out.sum() loss.backward() assert x.grad is not None def test_invalid_rate(self): with pytest.raises(ValueError): BernoulliDropout(rate=1.0, is_monte_carlo=True) def test_mc_applies_in_eval(self, x3d): """is_monte_carlo=True applies mask even in eval mode.""" torch.manual_seed(0) layer = BernoulliDropout(rate=0.9, is_monte_carlo=True, seed=1) layer.eval() out = layer(x3d) # With high rate some outputs should be zero assert out.sum() < x3d.sum() # --------------------------------------------------------------------------- # ConcreteDropout # --------------------------------------------------------------------------- class TestConcreteDropout: def test_forward_shape(self, x3d): N, C, D, H, W = x3d.shape layer = 
ConcreteDropout(in_channels=C, is_monte_carlo=True) layer.train() out = layer(x3d) assert out.shape == x3d.shape def test_kl_positive(self, x3d): N, C, D, H, W = x3d.shape layer = ConcreteDropout(in_channels=C, is_monte_carlo=True) layer.train() _ = layer(x3d) assert layer.kl_loss.item() > 0.0 def test_gradient_flow(self, x3d): N, C, D, H, W = x3d.shape x = x3d.clone().requires_grad_(True) layer = ConcreteDropout(in_channels=C, is_monte_carlo=True) layer.train() out = layer(x) # Gradient should flow through p_logit (learnable) loss = out.sum() + layer.kl_loss loss.backward() assert layer.p_logit.grad is not None def test_p_post_clipped(self, x3d): N, C, D, H, W = x3d.shape layer = ConcreteDropout(in_channels=C) p = layer.p_post assert (p >= 0.05).all() and (p <= 0.95).all() def test_passthrough_eval(self, x3d): N, C, D, H, W = x3d.shape layer = ConcreteDropout( in_channels=C, is_monte_carlo=False, use_expectation=False ) layer.eval() out = layer(x3d) assert torch.allclose(out, x3d) # --------------------------------------------------------------------------- # GaussianDropout # --------------------------------------------------------------------------- class TestGaussianDropout: def test_forward_shape(self, x3d): layer = GaussianDropout(rate=0.3, is_monte_carlo=True) layer.train() out = layer(x3d) assert out.shape == x3d.shape def test_passthrough_eval(self, x3d): layer = GaussianDropout(rate=0.3, is_monte_carlo=False) layer.eval() out = layer(x3d) assert torch.allclose(out, x3d) def test_gradient_flow(self, x3d): x = x3d.clone().requires_grad_(True) layer = GaussianDropout(rate=0.3, is_monte_carlo=True, seed=42) layer.train() out = layer(x) out.sum().backward() assert x.grad is not None def test_mc_in_eval(self, x3d): """is_monte_carlo=True adds noise even in eval mode.""" torch.manual_seed(0) layer = GaussianDropout(rate=0.3, is_monte_carlo=True) layer.eval() out = layer(x3d) # Output should differ from input due to noise assert not torch.allclose(out, x3d) def 
test_invalid_rate(self): with pytest.raises(ValueError): GaussianDropout(rate=-0.1, is_monte_carlo=True) # --------------------------------------------------------------------------- # MaxPool4D # --------------------------------------------------------------------------- class TestMaxPool4D: def test_forward_shape(self, x4d): layer = MaxPool4D(kernel_size=2, stride=2) out = layer(x4d) N, C, V, D, H, W = x4d.shape assert out.shape == (N, C, V, D // 2, H // 2, W // 2) def test_wrong_ndim(self): x = torch.ones(1, 1, 8, 8, 8) # 5-D layer = MaxPool4D(kernel_size=2) with pytest.raises(ValueError, match="6-D"): layer(x) def test_pool_v(self): x = torch.randn(1, 1, 4, 8, 8, 8) layer = MaxPool4D(kernel_size=2, stride=2, pool_v=2) out = layer(x) assert out.shape[2] == 2 # V reduced from 4 → 2 def test_gradient_flow(self, x4d): x = x4d.clone().float().requires_grad_(True) layer = MaxPool4D(kernel_size=2, stride=2) out = layer(x) out.sum().backward() assert x.grad is not None ================================================ FILE: nobrainer/tests/unit/test_losses.py ================================================ """Unit tests for nobrainer.losses (MONAI-backed).""" import pytest import torch import nobrainer.losses as losses_module # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _binary_pair(batch=2, spatial=16): """Return (y_true, y_pred) binary tensors of shape (B, 1, D, H, W).""" y_true = torch.randint(0, 2, (batch, 1, spatial, spatial, spatial)).float() y_pred = torch.sigmoid(torch.randn(batch, 1, spatial, spatial, spatial)) return y_true, y_pred def _multiclass_pair(batch=2, n_classes=3, spatial=8): """Return (y_true one-hot, y_pred softmax) tensors.""" labels = torch.randint(0, n_classes, (batch, spatial, spatial, spatial)) y_true = torch.zeros(batch, n_classes, spatial, spatial, spatial) y_true.scatter_(1, labels.unsqueeze(1), 1.0) y_pred = 
torch.softmax( torch.randn(batch, n_classes, spatial, spatial, spatial), dim=1 ) return y_true, y_pred # --------------------------------------------------------------------------- # dice # --------------------------------------------------------------------------- class TestDiceLoss: def test_returns_scalar(self): y_true, y_pred = _binary_pair() loss_fn = losses_module.dice(sigmoid=False) loss = loss_fn(y_pred, y_true) assert loss.ndim == 0 def test_non_negative(self): y_true, y_pred = _binary_pair() loss_fn = losses_module.dice(sigmoid=True) loss = loss_fn(y_pred, y_true) assert loss.item() >= 0.0 def test_perfect_prediction_near_zero(self): y = torch.ones(1, 1, 8, 8, 8) loss_fn = losses_module.dice() loss = loss_fn(y, y) assert loss.item() < 0.01 # --------------------------------------------------------------------------- # generalized_dice # --------------------------------------------------------------------------- class TestGeneralizedDiceLoss: def test_returns_scalar(self): y_true, y_pred = _multiclass_pair() loss_fn = losses_module.generalized_dice(softmax=False) loss = loss_fn(y_pred, y_true) assert loss.ndim == 0 def test_non_negative(self): y_true, y_pred = _multiclass_pair() loss_fn = losses_module.generalized_dice() loss = loss_fn(y_pred, y_true) assert loss.item() >= 0.0 # --------------------------------------------------------------------------- # jaccard # --------------------------------------------------------------------------- class TestJaccardLoss: def test_returns_scalar(self): y_true, y_pred = _binary_pair() loss_fn = losses_module.jaccard() loss = loss_fn(y_pred, y_true) assert loss.ndim == 0 def test_non_negative(self): y_true, y_pred = _binary_pair() loss_fn = losses_module.jaccard() loss = loss_fn(y_pred, y_true) assert loss.item() >= 0.0 # --------------------------------------------------------------------------- # tversky # --------------------------------------------------------------------------- class TestTverskyLoss: def 
test_returns_scalar(self): y_true, y_pred = _binary_pair() loss_fn = losses_module.tversky() loss = loss_fn(y_pred, y_true) assert loss.ndim == 0 def test_non_negative(self): y_true, y_pred = _binary_pair() loss_fn = losses_module.tversky(alpha=0.5, beta=0.5) loss = loss_fn(y_pred, y_true) assert loss.item() >= 0.0 # --------------------------------------------------------------------------- # stubs # --------------------------------------------------------------------------- class TestStubs: def test_elbo_returns_tensor(self): """elbo() is implemented in Phase 4; non-Bayesian model yields zero KL.""" import torch.nn as nn result = losses_module.elbo( nn.Linear(1, 1), kl_weight=1.0, reconstruction_loss=torch.tensor(0.5) ) assert isinstance(result, torch.Tensor) assert result.item() == pytest.approx(0.5) def test_wasserstein_returns_tensor(self): """wasserstein() is implemented in Phase 5; E[fake] - E[real].""" real_scores = torch.ones(4) fake_scores = torch.zeros(4) loss = losses_module.wasserstein(real_scores, fake_scores) assert isinstance(loss, torch.Tensor) # E[fake] - E[real] = 0 - 1 = -1 assert loss.item() == pytest.approx(-1.0) # --------------------------------------------------------------------------- # get() # --------------------------------------------------------------------------- class TestGet: def test_known_loss(self): fn = losses_module.get("dice") assert callable(fn) def test_unknown_raises(self): with pytest.raises(ValueError, match="Unknown loss"): losses_module.get("nonexistent") ================================================ FILE: nobrainer/tests/unit/test_metrics.py ================================================ """Unit tests for nobrainer.metrics (MONAI-backed).""" import pytest import torch import nobrainer.metrics as metrics_module # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _onehot_pair(batch=2, n_classes=2, 
spatial=8): """One-hot prediction and target tensors (B, C, D, H, W).""" labels = torch.randint(0, n_classes, (batch, spatial, spatial, spatial)) y_true = torch.zeros(batch, n_classes, spatial, spatial, spatial) y_true.scatter_(1, labels.unsqueeze(1), 1.0) y_pred = torch.softmax( torch.randn(batch, n_classes, spatial, spatial, spatial), dim=1 ) # MONAI metrics expect argmax-style binary predictions; use threshold y_pred_bin = (y_pred == y_pred.max(dim=1, keepdim=True).values).float() return y_true, y_pred_bin # --------------------------------------------------------------------------- # dice_metric # --------------------------------------------------------------------------- class TestDiceMetric: def test_instantiation(self): m = metrics_module.dice_metric() assert m is not None def test_perfect_score(self): y, _ = _onehot_pair(batch=1, n_classes=2, spatial=8) m = metrics_module.dice_metric(include_background=True) m(y_pred=y, y=y) result = m.aggregate() assert result.item() == pytest.approx(1.0, abs=1e-4) m.reset() def test_output_scalar(self): y_true, y_pred = _onehot_pair() m = metrics_module.dice_metric() m(y_pred=y_pred, y=y_true) result = m.aggregate() assert result.ndim == 0 or result.numel() == 1 # --------------------------------------------------------------------------- # jaccard_metric (MeanIoU) # --------------------------------------------------------------------------- class TestJaccardMetric: def test_instantiation(self): m = metrics_module.jaccard_metric() assert m is not None def test_perfect_score(self): y, _ = _onehot_pair(batch=1, n_classes=2, spatial=8) m = metrics_module.jaccard_metric() m(y_pred=y, y=y) result = m.aggregate() assert result.item() == pytest.approx(1.0, abs=1e-4) m.reset() # --------------------------------------------------------------------------- # hausdorff_metric # --------------------------------------------------------------------------- class TestHausdorffMetric: def test_instantiation(self): m = 
metrics_module.hausdorff_metric() assert m is not None def test_perfect_score_zero(self): y, _ = _onehot_pair(batch=1, n_classes=2, spatial=8) m = metrics_module.hausdorff_metric(include_background=False, percentile=95.0) m(y_pred=y, y=y) result = m.aggregate() assert result.item() == pytest.approx(0.0, abs=1e-4) m.reset() # --------------------------------------------------------------------------- # get() # --------------------------------------------------------------------------- class TestGet: def test_known_metric(self): fn = metrics_module.get("dice") assert callable(fn) def test_unknown_raises(self): with pytest.raises(ValueError, match="Unknown metric"): metrics_module.get("nonexistent") ================================================ FILE: nobrainer/tests/unit/test_model_interface.py ================================================ """Unit tests for unified model forward interface.""" from __future__ import annotations import torch from nobrainer.models import get from nobrainer.models._utils import model_supports_mc class TestUnifiedForward: """All models accept model(x) without error.""" def test_meshnet(self): model = get("meshnet")(n_classes=2, filters=8, receptive_field=37) x = torch.randn(1, 1, 16, 16, 16) out = model(x) assert out.shape == (1, 2, 16, 16, 16) def test_unet(self): model = get("unet")(n_classes=2, channels=(4, 8), strides=(2,)) x = torch.randn(1, 1, 16, 16, 16) out = model(x) assert out.shape == (1, 2, 16, 16, 16) def test_segformer3d(self): model = get("segformer3d")(n_classes=2, embed_dims=(16, 32, 80, 128)) model.eval() x = torch.randn(1, 1, 32, 32, 32) with torch.no_grad(): out = model(x) assert out.shape == (1, 2, 32, 32, 32) class TestMcSupport: """Bayesian models support mc parameter.""" def test_kwyk_meshnet_supports_mc(self): model = get("kwyk_meshnet")(n_classes=2, filters=8, receptive_field=37) assert model_supports_mc(model) x = torch.randn(1, 1, 16, 16, 16) out_det = model(x, mc=False) out_mc = model(x, mc=True) assert 
out_det.shape == (1, 2, 16, 16, 16) assert out_mc.shape == (1, 2, 16, 16, 16) def test_bayesian_meshnet_supports_mc(self): import pyro pyro.clear_param_store() model = get("bayesian_meshnet")(n_classes=2, filters=8, receptive_field=37) assert model_supports_mc(model) x = torch.randn(1, 1, 16, 16, 16) with pyro.poutine.trace(): out = model(x) assert out.shape == (1, 2, 16, 16, 16) def test_regular_model_no_mc(self): model = get("meshnet")(n_classes=2, filters=8, receptive_field=37) assert not model_supports_mc(model) def test_forward_helper_uses_explicit_check(self): """_forward does NOT use try/except TypeError.""" import inspect from nobrainer.prediction import _forward source = inspect.getsource(_forward) assert "except TypeError" not in source assert "model_supports_mc" in source ================================================ FILE: nobrainer/tests/unit/test_model_registry.py ================================================ """Unit tests for SwinUNETR and SegResNet model registration.""" from __future__ import annotations import torch from nobrainer.models import get class TestSwinUNETR: def test_instantiate(self): model = get("swin_unetr")(n_classes=2, feature_size=12) assert model is not None def test_output_shape(self): model = get("swin_unetr")(n_classes=3, feature_size=12) model.eval() # SwinUNETR needs input >= 64³ due to window attention + instance norm x = torch.randn(1, 1, 64, 64, 64) with torch.no_grad(): out = model(x) assert out.shape == (1, 3, 64, 64, 64) class TestSegResNet: def test_instantiate(self): model = get("segresnet")(n_classes=2, init_filters=8) assert model is not None def test_output_shape(self): model = get("segresnet")(n_classes=5, init_filters=8, blocks_down=(1, 2, 2, 4)) x = torch.randn(1, 1, 32, 32, 32) out = model(x) assert out.shape == (1, 5, 32, 32, 32) class TestRegistryAccess: def test_swin_unetr_in_registry(self): from nobrainer.models import available_models assert "swin_unetr" in available_models() def 
test_segresnet_in_registry(self): from nobrainer.models import available_models assert "segresnet" in available_models() ================================================ FILE: nobrainer/tests/unit/test_models_segmentation.py ================================================ """Unit tests for nobrainer segmentation models (PyTorch).""" import pytest import torch from nobrainer.models import get as get_model from nobrainer.models.autoencoder import autoencoder from nobrainer.models.highresnet import highresnet from nobrainer.models.meshnet import meshnet from nobrainer.models.segmentation import attention_unet, unet, unetr, vnet from nobrainer.models.simsiam import simsiam # Small spatial size to keep tests fast on CPU SPATIAL = 32 IN_SHAPE = (1, 1, SPATIAL, SPATIAL, SPATIAL) def _grad_check(model: torch.nn.Module, inp: torch.Tensor) -> bool: """Return True if gradients flow through all parameters.""" model.train() out = model(inp) if isinstance(out, tuple): loss = sum(o.mean() for o in out) else: loss = out.mean() loss.backward() return all(p.grad is not None for p in model.parameters() if p.requires_grad) # --------------------------------------------------------------------------- # UNet (MONAI) # --------------------------------------------------------------------------- class TestUNet: def test_output_shape_binary(self): m = unet(n_classes=1) x = torch.randn(*IN_SHAPE) assert m(x).shape == (1, 1, SPATIAL, SPATIAL, SPATIAL) def test_output_shape_multiclass(self): m = unet(n_classes=3) x = torch.randn(*IN_SHAPE) assert m(x).shape == (1, 3, SPATIAL, SPATIAL, SPATIAL) def test_gradient_flow(self): m = unet(n_classes=2) x = torch.randn(*IN_SHAPE) assert _grad_check(m, x) def test_get_registry(self): fn = get_model("unet") assert fn is unet # --------------------------------------------------------------------------- # VNet (MONAI) # --------------------------------------------------------------------------- class TestVNet: def test_output_shape(self): m = 
vnet(n_classes=1) x = torch.randn(*IN_SHAPE) out = m(x) assert out.shape == (1, 1, SPATIAL, SPATIAL, SPATIAL) def test_gradient_flow(self): m = vnet(n_classes=2) x = torch.randn(*IN_SHAPE) assert _grad_check(m, x) # --------------------------------------------------------------------------- # Attention UNet (MONAI) # --------------------------------------------------------------------------- class TestUNETR: def test_output_shape(self): m = unetr( n_classes=2, img_size=(SPATIAL, SPATIAL, SPATIAL), hidden_size=192, mlp_dim=768, num_heads=12, feature_size=8, ) x = torch.randn(1, 1, SPATIAL, SPATIAL, SPATIAL) m.eval() with torch.no_grad(): out = m(x) assert out.shape == (1, 2, SPATIAL, SPATIAL, SPATIAL) class TestAttentionUNet: def test_output_shape(self): m = attention_unet( n_classes=1, channels=(8, 16, 32), strides=(2, 2), ) x = torch.randn(*IN_SHAPE) assert m(x).shape == (1, 1, SPATIAL, SPATIAL, SPATIAL) def test_gradient_flow(self): m = attention_unet( n_classes=2, channels=(8, 16, 32), strides=(2, 2), ) x = torch.randn(*IN_SHAPE) assert _grad_check(m, x) # --------------------------------------------------------------------------- # MeshNet (custom PyTorch) # --------------------------------------------------------------------------- class TestMeshNet: def test_output_shape_binary(self): m = meshnet(n_classes=1) x = torch.randn(*IN_SHAPE) assert m(x).shape == (1, 1, SPATIAL, SPATIAL, SPATIAL) def test_output_shape_multiclass(self): m = meshnet(n_classes=3) x = torch.randn(*IN_SHAPE) assert m(x).shape == (1, 3, SPATIAL, SPATIAL, SPATIAL) def test_receptive_field_37(self): m = meshnet(n_classes=1, receptive_field=37) x = torch.randn(*IN_SHAPE) assert m(x).shape == (1, 1, SPATIAL, SPATIAL, SPATIAL) def test_receptive_field_129(self): m = meshnet(n_classes=1, receptive_field=129) x = torch.randn(*IN_SHAPE) assert m(x).shape == (1, 1, SPATIAL, SPATIAL, SPATIAL) def test_invalid_rf(self): with pytest.raises(ValueError, match="receptive_field"): meshnet(n_classes=1, 
receptive_field=999) def test_gradient_flow(self): m = meshnet(n_classes=2) x = torch.randn(*IN_SHAPE) assert _grad_check(m, x) # --------------------------------------------------------------------------- # HighResNet (custom PyTorch) # --------------------------------------------------------------------------- class TestHighResNet: def test_output_shape_binary(self): m = highresnet(n_classes=1) x = torch.randn(*IN_SHAPE) assert m(x).shape == (1, 1, SPATIAL, SPATIAL, SPATIAL) def test_output_shape_multiclass(self): m = highresnet(n_classes=3) x = torch.randn(*IN_SHAPE) assert m(x).shape == (1, 3, SPATIAL, SPATIAL, SPATIAL) def test_gradient_flow(self): m = highresnet(n_classes=2) x = torch.randn(*IN_SHAPE) assert _grad_check(m, x) # --------------------------------------------------------------------------- # Autoencoder (custom PyTorch) # --------------------------------------------------------------------------- class TestAutoencoder: # Use batch=2 to avoid BatchNorm single-sample issues def test_output_shape(self): m = autoencoder(input_shape=(SPATIAL, SPATIAL, SPATIAL), encoding_dim=64) x = torch.randn(2, 1, SPATIAL, SPATIAL, SPATIAL) out = m(x) assert out.shape == x.shape def test_encode_shape(self): m = autoencoder(input_shape=(SPATIAL, SPATIAL, SPATIAL), encoding_dim=64) x = torch.randn(2, 1, SPATIAL, SPATIAL, SPATIAL) z = m.encode(x) assert z.shape == (2, 64) def test_gradient_flow(self): m = autoencoder(input_shape=(SPATIAL, SPATIAL, SPATIAL), encoding_dim=32) x = torch.randn(2, 1, SPATIAL, SPATIAL, SPATIAL) assert _grad_check(m, x) # --------------------------------------------------------------------------- # SimSiam (custom PyTorch) # --------------------------------------------------------------------------- class TestSimSiam: # Use batch=2 to avoid BatchNorm1d single-sample issues def test_forward_shapes(self): m = simsiam(projection_dim=128, latent_dim=64) x1 = torch.randn(2, 1, SPATIAL, SPATIAL, SPATIAL) x2 = torch.randn(2, 1, SPATIAL, SPATIAL, 
SPATIAL) p1, p2, z1, z2 = m(x1, x2) assert p1.shape == (2, 128) assert z1.shape == (2, 128) def test_loss_negative_range(self): m = simsiam(projection_dim=128, latent_dim=64) x1 = torch.randn(2, 1, SPATIAL, SPATIAL, SPATIAL) x2 = torch.randn(2, 1, SPATIAL, SPATIAL, SPATIAL) p1, p2, z1, z2 = m(x1, x2) loss = m.loss(p1, p2, z1, z2) # Loss should be in [-1, 0] for cosine similarity assert -1.1 <= loss.item() <= 0.1 def test_gradient_flow(self): m = simsiam(projection_dim=128, latent_dim=64) m.train() x1 = torch.randn(2, 1, SPATIAL, SPATIAL, SPATIAL) x2 = torch.randn(2, 1, SPATIAL, SPATIAL, SPATIAL) p1, p2, z1, z2 = m(x1, x2) loss = m.loss(p1, p2, z1, z2) loss.backward() assert all( p.grad is not None for p in m.projector.parameters() if p.requires_grad ) ================================================ FILE: nobrainer/tests/unit/test_prediction.py ================================================ """Unit tests for predict() and predict_with_uncertainty().""" from __future__ import annotations from pathlib import Path import tempfile import nibabel as nib import numpy as np import torch.nn as nn from nobrainer.prediction import predict, predict_with_uncertainty # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- class _IdentityModel(nn.Module): """Minimal 1-class model: sigmoid of a 1×1×1 conv applied to input.""" def __init__(self): super().__init__() self.conv = nn.Conv3d(1, 1, kernel_size=1) def forward(self, x): return self.conv(x) class _MultiClassModel(nn.Module): """Minimal 3-class model.""" def __init__(self): super().__init__() self.conv = nn.Conv3d(1, 3, kernel_size=1) def forward(self, x): return self.conv(x) def _make_nifti(shape=(32, 32, 32), tmp_path=None) -> str: if tmp_path is None: tmp_path = Path(tempfile.mkdtemp()) data = np.random.rand(*shape).astype(np.float32) img = nib.Nifti1Image(data, np.eye(4)) path = str(tmp_path / 
f"vol_{np.random.randint(0, 1e6)}.nii.gz") nib.save(img, path) return path # --------------------------------------------------------------------------- # predict() # --------------------------------------------------------------------------- class TestPredict: def test_returns_nifti(self, tmp_path): path = _make_nifti((16, 16, 16), tmp_path) model = _IdentityModel() out = predict(path, model, block_shape=(8, 8, 8), batch_size=2) assert isinstance(out, nib.Nifti1Image) def test_output_shape_matches_input(self, tmp_path): path = _make_nifti((16, 16, 16), tmp_path) model = _IdentityModel() out = predict(path, model, block_shape=(8, 8, 8), batch_size=2) assert out.shape == (16, 16, 16) def test_ndarray_input(self): arr = np.random.rand(16, 16, 16).astype(np.float32) model = _IdentityModel() out = predict(arr, model, block_shape=(8, 8, 8), batch_size=2) assert out.shape == (16, 16, 16) def test_nifti_image_input(self): arr = np.random.rand(16, 16, 16).astype(np.float32) img = nib.Nifti1Image(arr, np.eye(4)) model = _IdentityModel() out = predict(img, model, block_shape=(8, 8, 8), batch_size=2) assert out.shape == (16, 16, 16) def test_affine_preserved(self, tmp_path): path = _make_nifti((16, 16, 16), tmp_path) model = _IdentityModel() src_affine = nib.load(path).affine out = predict(path, model, block_shape=(8, 8, 8), batch_size=2) assert np.allclose(out.affine, src_affine) def test_return_probabilities(self): arr = np.random.rand(16, 16, 16).astype(np.float32) model = _MultiClassModel() out = predict( arr, model, block_shape=(8, 8, 8), batch_size=2, return_labels=False ) # 3-class probabilities → shape (3, D, H, W) assert out.shape[:1] == (3,) or out.ndim == 4 def test_non_block_aligned_input(self): """Volume with shape not divisible by block_shape should still work.""" arr = np.random.rand(20, 20, 20).astype(np.float32) model = _IdentityModel() out = predict(arr, model, block_shape=(8, 8, 8), batch_size=2) assert out.shape == (20, 20, 20) # 
--------------------------------------------------------------------------- # predict_with_uncertainty() # --------------------------------------------------------------------------- class TestPredictWithUncertainty: def test_returns_three_niftis(self): arr = np.random.rand(16, 16, 16).astype(np.float32) model = _IdentityModel() label, var, entropy = predict_with_uncertainty( arr, model, n_samples=3, block_shape=(8, 8, 8), batch_size=2 ) assert isinstance(label, nib.Nifti1Image) assert isinstance(var, nib.Nifti1Image) assert isinstance(entropy, nib.Nifti1Image) def test_output_shapes_match_input(self): arr = np.random.rand(16, 16, 16).astype(np.float32) model = _IdentityModel() label, var, entropy = predict_with_uncertainty( arr, model, n_samples=3, block_shape=(8, 8, 8), batch_size=2 ) assert label.shape == (16, 16, 16) assert var.shape == (16, 16, 16) assert entropy.shape == (16, 16, 16) def test_variance_nonnegative(self): arr = np.random.rand(16, 16, 16).astype(np.float32) model = _IdentityModel() _, var, _ = predict_with_uncertainty( arr, model, n_samples=3, block_shape=(8, 8, 8), batch_size=2 ) assert (np.asarray(var.dataobj) >= 0).all() def test_entropy_nonnegative(self): arr = np.random.rand(16, 16, 16).astype(np.float32) model = _IdentityModel() _, _, entropy = predict_with_uncertainty( arr, model, n_samples=3, block_shape=(8, 8, 8), batch_size=2 ) assert (np.asarray(entropy.dataobj) >= 0).all() ================================================ FILE: nobrainer/tests/unit/test_research_commit.py ================================================ """Unit tests for commit_best_model in nobrainer.research.loop.""" from __future__ import annotations import json from pathlib import Path from unittest.mock import MagicMock, patch import pytest def _make_model_files(tmp_path: Path) -> tuple[Path, Path]: """Create dummy model and config files.""" model_path = tmp_path / "best_model.pth" model_path.write_bytes(b"\x00" * 16) # dummy weights config_path = tmp_path / 
"best_config.json" config_path.write_text(json.dumps({"learning_rate": 1e-4, "batch_size": 4})) return model_path, config_path class TestCommitBestModel: def test_directory_structure_created(self, tmp_path): """commit_best_model creates the expected subdirectory.""" model_path, config_path = _make_model_files(tmp_path) trained_models = tmp_path / "trained_models" trained_models.mkdir() mock_dl = MagicMock() mock_dl.save = MagicMock() mock_dl.push = MagicMock() with patch.dict( "sys.modules", {"datalad": MagicMock(), "datalad.api": mock_dl} ): from nobrainer.research.loop import commit_best_model result = commit_best_model( best_model_path=model_path, best_config_path=config_path, trained_models_path=trained_models, model_family="bayesian_vnet", val_dice=0.85, source_run_id="run_001", ) dest = Path(result["path"]) assert dest.exists() assert (dest / "model.pth").exists() assert (dest / "config.json").exists() assert (dest / "model_card.md").exists() def test_model_card_contains_required_fields(self, tmp_path): """model_card.md contains architecture, val_dice, source_run_id.""" model_path, config_path = _make_model_files(tmp_path) trained_models = tmp_path / "trained_models" trained_models.mkdir() mock_dl = MagicMock() with patch.dict( "sys.modules", {"datalad": MagicMock(), "datalad.api": mock_dl} ): from nobrainer.research.loop import commit_best_model result = commit_best_model( best_model_path=model_path, best_config_path=config_path, trained_models_path=trained_models, model_family="bayesian_vnet", val_dice=0.85, source_run_id="run_42", ) card = (Path(result["path"]) / "model_card.md").read_text() assert "bayesian_vnet" in card assert "0.8500" in card assert "run_42" in card assert "PyTorch" in card def test_model_version_dict_fields(self, tmp_path): """commit_best_model returns ModelVersion dict with expected keys.""" model_path, config_path = _make_model_files(tmp_path) trained_models = tmp_path / "trained_models" trained_models.mkdir() mock_dl = MagicMock() 
with patch.dict( "sys.modules", {"datalad": MagicMock(), "datalad.api": mock_dl} ): from nobrainer.research.loop import commit_best_model result = commit_best_model( best_model_path=model_path, best_config_path=config_path, trained_models_path=trained_models, model_family="bayesian_vnet", val_dice=0.75, ) assert "path" in result assert "datalad_commit" in result assert "val_dice" in result assert "model_family" in result assert result["val_dice"] == pytest.approx(0.75) assert result["model_family"] == "bayesian_vnet" def test_datalad_commit_message_in_result(self, tmp_path): """commit_best_model result contains a descriptive datalad_commit message.""" model_path, config_path = _make_model_files(tmp_path) trained_models = tmp_path / "trained_models" trained_models.mkdir() mock_dl = MagicMock() with patch.dict( "sys.modules", {"datalad": MagicMock(), "datalad.api": mock_dl} ): from nobrainer.research.loop import commit_best_model result = commit_best_model( best_model_path=model_path, best_config_path=config_path, trained_models_path=trained_models, model_family="bayesian_vnet", val_dice=0.9, ) assert "bayesian_vnet" in result["datalad_commit"] assert "0.9000" in result["datalad_commit"] def test_result_contains_osf_url_key(self, tmp_path): """commit_best_model result always contains the osf_url key.""" model_path, config_path = _make_model_files(tmp_path) trained_models = tmp_path / "trained_models" trained_models.mkdir() with patch.dict( "sys.modules", {"datalad": MagicMock(), "datalad.api": MagicMock()} ): from nobrainer.research.loop import commit_best_model result = commit_best_model( best_model_path=model_path, best_config_path=config_path, trained_models_path=trained_models, model_family="bayesian_vnet", val_dice=0.8, ) # osf_url is present; it is either 'osf://' (push succeeded) or None assert "osf_url" in result def test_datalad_not_installed_raises_import_error(self, tmp_path): """ImportError raised with helpful message when datalad missing.""" model_path, 
config_path = _make_model_files(tmp_path) trained_models = tmp_path / "trained_models" trained_models.mkdir() with patch.dict("sys.modules", {"datalad": None, "datalad.api": None}): from nobrainer.research.loop import commit_best_model with pytest.raises(ImportError, match="datalad"): commit_best_model( best_model_path=model_path, best_config_path=config_path, trained_models_path=trained_models, model_family="bayesian_vnet", val_dice=0.8, ) ================================================ FILE: nobrainer/tests/unit/test_research_loop.py ================================================ """Unit tests for the autoresearch run_loop.""" from __future__ import annotations import json from pathlib import Path from unittest.mock import patch import pytest from nobrainer.research.loop import ( ExperimentResult, _classify_failure, _has_nan, _parse_config_comment, _patch_config, _read_val_dice, _write_summary, run_loop, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _write_train_script(path: Path, config: dict | None = None) -> None: cfg = config or {"learning_rate": 1e-4, "batch_size": 4} path.write_text( f"# CONFIG: {json.dumps(cfg)}\n" "import sys; print('training done'); sys.exit(0)\n" ) def _write_val_dice(path: Path, val_dice: float) -> None: path.write_text(json.dumps({"val_dice": val_dice})) # --------------------------------------------------------------------------- # Unit helpers # --------------------------------------------------------------------------- class TestHelpers: def test_parse_config_comment(self, tmp_path): script = tmp_path / "train.py" script.write_text("# CONFIG: {\"lr\": 1e-4}\nprint('hi')\n") config = _parse_config_comment(script) assert config["lr"] == pytest.approx(1e-4) def test_parse_config_comment_missing(self, tmp_path): script = tmp_path / "train.py" script.write_text("print('no config')\n") config = 
_parse_config_comment(script) assert config == {} def test_patch_config(self, tmp_path): script = tmp_path / "train.py" script.write_text("# CONFIG: {\"lr\": 1e-4}\nprint('hi')\n") _patch_config(script, {"lr": 5e-4}) content = script.read_text() assert ( '"lr": 0.0005' in content or '"lr": 5e-4' in content or "5e-04" in content ) def test_patch_config_adds_when_missing(self, tmp_path): script = tmp_path / "train.py" script.write_text("print('no config')\n") _patch_config(script, {"lr": 1e-3}) content = script.read_text() assert "# CONFIG:" in content def test_read_val_dice_valid(self, tmp_path): (tmp_path / "val_dice.json").write_text('{"val_dice": 0.85}') assert _read_val_dice(tmp_path / "val_dice.json") == pytest.approx(0.85) def test_read_val_dice_missing(self, tmp_path): assert _read_val_dice(tmp_path / "nonexistent.json") is None def test_has_nan(self): assert _has_nan("loss: nan after epoch 3") assert not _has_nan("loss: 0.25") def test_classify_failure_oom(self): assert _classify_failure("CUDA out of memory") == "CUDA OOM" def test_classify_failure_nan(self): assert _classify_failure("nan in grad") == "NaN in loss" def test_classify_failure_generic(self): assert _classify_failure("some error") == "non-zero exit code" def test_write_summary(self, tmp_path): results = [ ExperimentResult(0, {}, 0.8, "improved"), ExperimentResult(1, {}, 0.79, "degraded"), ] _write_summary(tmp_path, results, "bayesian_vnet", 0.8) summary = (tmp_path / "run_summary.md").read_text() assert "bayesian_vnet" in summary assert "0.8000" in summary # --------------------------------------------------------------------------- # run_loop integration tests (subprocess mocked) # --------------------------------------------------------------------------- class TestRunLoop: def test_keep_improved_experiment(self, tmp_path): """run_loop keeps config when val_dice improves.""" _write_train_script(tmp_path / "train.py") _write_val_dice(tmp_path / "val_dice.json", 0.9) with ( patch( 
"nobrainer.research.loop._propose_config", side_effect=[ {"learning_rate": 5e-4, "batch_size": 4}, ] * 5, ), patch( "nobrainer.research.loop.subprocess.run", ) as mock_run, ): mock_run.return_value.returncode = 0 mock_run.return_value.stdout = "training done\n" mock_run.return_value.stderr = "" results = run_loop( tmp_path, max_experiments=2, budget_hours=1.0, ) improved = [r for r in results if r.outcome == "improved"] assert len(improved) >= 1 def test_revert_on_degraded(self, tmp_path): """run_loop reverts train.py when val_dice degrades.""" original_content = ( f"# CONFIG: {json.dumps({'learning_rate': 1e-4, 'batch_size': 4})}\n" "import sys; sys.exit(0)\n" ) (tmp_path / "train.py").write_text(original_content) # First experiment improves, second degrades dices = [0.8, 0.7] call_count = [0] def _mock_run(cmd, **kwargs): from unittest.mock import MagicMock dice = dices[call_count[0] % len(dices)] _write_val_dice(tmp_path / "val_dice.json", dice) call_count[0] += 1 r = MagicMock() r.returncode = 0 r.stdout = "done\n" r.stderr = "" return r with ( patch("nobrainer.research.loop.subprocess.run", side_effect=_mock_run), patch( "nobrainer.research.loop._propose_config", side_effect=[{"learning_rate": 5e-4}] * 5, ), ): results = run_loop(tmp_path, max_experiments=2, budget_hours=1.0) degraded = [r for r in results if r.outcome == "degraded"] assert len(degraded) >= 1 def test_failure_handling_reverts(self, tmp_path): """run_loop reverts train.py when subprocess fails.""" _write_train_script(tmp_path / "train.py") original = (tmp_path / "train.py").read_text() with ( patch( "nobrainer.research.loop.subprocess.run", ) as mock_run, patch( "nobrainer.research.loop._propose_config", return_value={"learning_rate": 1e-3}, ), ): mock_run.return_value.returncode = 1 mock_run.return_value.stdout = "" mock_run.return_value.stderr = "some error" results = run_loop(tmp_path, max_experiments=1, budget_hours=1.0) assert results[0].outcome == "failed" # Train script reverted assert 
(tmp_path / "train.py").read_text() == original def test_run_summary_written(self, tmp_path): """run_summary.md is written after the loop.""" _write_train_script(tmp_path / "train.py") with ( patch( "nobrainer.research.loop.subprocess.run", ) as mock_run, patch( "nobrainer.research.loop._propose_config", return_value={"learning_rate": 1e-4}, ), ): mock_run.return_value.returncode = 1 mock_run.return_value.stdout = "" mock_run.return_value.stderr = "error" run_loop(tmp_path, max_experiments=1, budget_hours=1.0) assert (tmp_path / "run_summary.md").exists() def test_missing_train_script_raises(self, tmp_path): with pytest.raises(FileNotFoundError): run_loop(tmp_path, max_experiments=1, budget_hours=1.0) def test_budget_seconds_terminates_quickly(self, tmp_path): """T013: budget_seconds=10 should terminate within 15s.""" import time (tmp_path / "train.py").write_text( "import json, time; time.sleep(0.1);\n" 'json.dump({"val_dice": 0.5}, open("val_dice.json", "w"))\n' ) start = time.time() with patch( "nobrainer.research.loop._propose_config", return_value={}, ): run_loop( tmp_path, max_experiments=100, budget_seconds=5, ) elapsed = time.time() - start assert elapsed < 15, f"Loop took {elapsed:.1f}s, expected < 15s" ================================================ FILE: nobrainer/tests/unit/test_segformer3d.py ================================================ """Unit tests for SegFormer3D model.""" from __future__ import annotations import torch from nobrainer.models import get from nobrainer.models.segformer3d import SegFormer3D class TestSegFormer3DShapes: def test_output_shape_32(self): model = SegFormer3D(n_classes=2, embed_dims=(16, 32, 80, 128)) model.eval() x = torch.randn(1, 1, 32, 32, 32) with torch.no_grad(): out = model(x) assert out.shape == (1, 2, 32, 32, 32) def test_output_shape_64(self): model = SegFormer3D(n_classes=5, embed_dims=(16, 32, 80, 128)) model.eval() x = torch.randn(1, 1, 64, 64, 64) with torch.no_grad(): out = model(x) assert out.shape == 
(1, 5, 64, 64, 64) def test_batch_size_2(self): model = SegFormer3D(n_classes=3, embed_dims=(16, 32, 80, 128)) model.eval() x = torch.randn(2, 1, 32, 32, 32) with torch.no_grad(): out = model(x) assert out.shape == (2, 3, 32, 32, 32) class TestSegFormer3DParams: def test_default_param_count(self): """Default (small) config should have ~4-5M params.""" model = SegFormer3D(n_classes=50) n_params = sum(p.numel() for p in model.parameters()) assert n_params < 10_000_000 # < 10M def test_tiny_param_count(self): """Tiny config should have ~1-2M params.""" model = SegFormer3D(n_classes=50, embed_dims=(16, 32, 80, 128)) n_params = sum(p.numel() for p in model.parameters()) assert n_params < 5_000_000 # < 5M def test_base_param_count(self): """Base config should have ~15-20M params.""" model = SegFormer3D(n_classes=50, embed_dims=(64, 128, 320, 512)) n_params = sum(p.numel() for p in model.parameters()) assert n_params > 10_000_000 # > 10M class TestSegFormer3DRegistry: def test_accessible_via_get(self): model = get("segformer3d")(n_classes=2, embed_dims=(16, 32, 80, 128)) assert model is not None assert isinstance(model, SegFormer3D) def test_in_available_models(self): from nobrainer.models import available_models assert "segformer3d" in available_models() def test_factory_defaults(self): from nobrainer.models.segformer3d import segformer3d model = segformer3d(n_classes=2) assert isinstance(model, SegFormer3D) ================================================ FILE: nobrainer/tests/unit/test_slurm.py ================================================ """Unit tests for nobrainer.slurm utilities.""" from __future__ import annotations import torch from nobrainer.slurm import SlurmPreemptionHandler, load_checkpoint, save_checkpoint class TestSlurmPreemptionHandler: def test_initial_state(self): h = SlurmPreemptionHandler() assert h.preempted is False def test_is_slurm_job(self): # In test environment, should be False assert isinstance(SlurmPreemptionHandler.is_slurm_job(), bool) 
class TestCheckpoint: def test_save_and_load(self, tmp_path): model = torch.nn.Linear(4, 2) opt = torch.optim.SGD(model.parameters(), lr=0.01) metrics = {"train_losses": [1.0, 0.5], "best_loss": 0.5} save_checkpoint(tmp_path, model, opt, epoch=5, metrics=metrics) assert (tmp_path / "checkpoint.pt").exists() assert (tmp_path / "checkpoint_meta.json").exists() model2 = torch.nn.Linear(4, 2) opt2 = torch.optim.SGD(model2.parameters(), lr=0.01) start, restored = load_checkpoint(tmp_path, model2, opt2) assert start == 6 # next epoch assert restored["best_loss"] == 0.5 assert len(restored["train_losses"]) == 2 def test_load_no_checkpoint(self, tmp_path): model = torch.nn.Linear(4, 2) start, metrics = load_checkpoint(tmp_path, model) assert start == 0 assert metrics == {} def test_model_weights_restored(self, tmp_path): model = torch.nn.Linear(4, 2) model.weight.data.fill_(42.0) opt = torch.optim.SGD(model.parameters(), lr=0.01) save_checkpoint(tmp_path, model, opt, epoch=0) model2 = torch.nn.Linear(4, 2) load_checkpoint(tmp_path, model2) assert torch.allclose(model2.weight.data, torch.tensor(42.0)) ================================================ FILE: nobrainer/tests/unit/test_stride_patches.py ================================================ """Unit tests for strided patch extraction and reassembly.""" from __future__ import annotations import numpy as np from nobrainer.prediction import reassemble_predictions, strided_patch_positions class TestStridedPatchPositions: def test_non_overlapping_count(self): """256³ with block=32 stride=32 → 8³ = 512 patches.""" pos = strided_patch_positions((256, 256, 256), (32, 32, 32), (32, 32, 32)) assert len(pos) == 8 * 8 * 8 # 512 def test_overlapping_more_patches(self): """Stride < block produces more patches.""" non_overlap = strided_patch_positions((64, 64, 64), (32, 32, 32), (32, 32, 32)) overlap = strided_patch_positions((64, 64, 64), (32, 32, 32), (16, 16, 16)) assert len(overlap) > len(non_overlap) def 
test_patch_shapes_valid(self): """Each position should yield a valid slice.""" pos = strided_patch_positions((100, 100, 100), (32, 32, 32), (16, 16, 16)) for sd, sh, sw in pos: assert sd.stop - sd.start == 32 assert sh.stop - sh.start == 32 assert sw.stop - sw.start == 32 assert sd.stop <= 100 assert sh.stop <= 100 assert sw.stop <= 100 def test_stride_equals_block_default(self): """None stride defaults to block_shape.""" pos = strided_patch_positions((64, 64, 64), (32, 32, 32)) assert len(pos) == 2 * 2 * 2 # 8 class TestReassemblePredictions: def test_non_overlapping_perfect_reconstruction(self): """Non-overlapping patches reassemble perfectly.""" vol_shape = (64, 64, 64) block = (32, 32, 32) n_classes = 2 # Create a known volume original = np.random.randn(n_classes, *vol_shape).astype(np.float32) # Extract non-overlapping patches positions = strided_patch_positions(vol_shape, block, block) patches = [] for sd, sh, sw in positions: patch = original[:, sd, sh, sw] patches.append((patch, (sd, sh, sw))) # Reassemble result = reassemble_predictions(patches, vol_shape, n_classes) assert np.allclose(result, original, atol=1e-6) def test_overlapping_average(self): """Overlapping patches with averaging should still reconstruct reasonably.""" vol_shape = (64, 64, 64) block = (32, 32, 32) stride = (16, 16, 16) n_classes = 2 # Create constant volume (averaging constant = constant) original = np.ones((n_classes, *vol_shape), dtype=np.float32) * 0.5 positions = strided_patch_positions(vol_shape, block, stride) patches = [] for sd, sh, sw in positions: patch = original[:, sd, sh, sw] patches.append((patch, (sd, sh, sw))) result = reassemble_predictions( patches, vol_shape, n_classes, strategy="average" ) assert np.allclose(result, 0.5, atol=1e-5) def test_output_shape(self): """Output shape matches volume_shape.""" patches = [ (np.ones((3, 16, 16, 16)), (slice(0, 16), slice(0, 16), slice(0, 16))) ] result = reassemble_predictions(patches, (32, 32, 32), 3) assert result.shape == 
(3, 32, 32, 32) ================================================ FILE: nobrainer/tests/unit/test_synthseg.py ================================================ """Unit tests for enhanced SynthSeg generator.""" from __future__ import annotations from pathlib import Path import nibabel as nib import numpy as np import pytest import torch def _make_label_map(tmp_path: Path, shape=(32, 32, 32)) -> str: """Create a simple label map with a few regions.""" arr = np.zeros(shape, dtype=np.int32) # Background = 0, WM = 2, GM = 3, CSF = 4, hippocampus L = 17, R = 53 arr[4:28, 4:28, 4:28] = 2 # WM core arr[6:26, 6:26, 6:26] = 3 # GM shell arr[12:20, 12:20, 12:20] = 4 # CSF center arr[8:12, 8:12, 8:16] = 17 # L hippocampus arr[8:12, 8:12, 16:24] = 53 # R hippocampus path = str(tmp_path / "label.nii.gz") nib.save(nib.Nifti1Image(arr, np.eye(4)), path) return path class TestTissueClasses: def test_all_50class_labels_covered(self): from nobrainer.data.tissue_classes import FREESURFER_TISSUE_CLASSES all_ids = set() for ids in FREESURFER_TISSUE_CLASSES.values(): all_ids.update(ids) # Should cover background + major structures assert 0 in all_ids # background assert 2 in all_ids # L cerebral WM assert 41 in all_ids # R cerebral WM assert 17 in all_ids # L hippocampus assert 53 in all_ids # R hippocampus def test_no_label_in_multiple_classes(self): from nobrainer.data.tissue_classes import FREESURFER_TISSUE_CLASSES seen = {} for cls_name, ids in FREESURFER_TISSUE_CLASSES.items(): for lid in ids: assert ( lid not in seen ), f"Label {lid} in both '{seen[lid]}' and '{cls_name}'" seen[lid] = cls_name class TestGMMGrouping: def test_within_class_same_distribution(self, tmp_path): from nobrainer.augmentation.synthseg import SynthSegGenerator path = _make_label_map(tmp_path) gen = SynthSegGenerator( [path], n_samples_per_map=1, elastic_std=0, rotation_range=0, flipping=False, randomize_resolution=False, noise_std=0, bias_field_std=0, ) sample = gen[0] image = sample["image"][0].numpy() # (D, 
H, W) label = sample["label"][0].numpy() # L hippocampus (17) and R hippocampus (53) are both in "hippocampus" class # They should have similar mean intensities (same GMM class) l_hip = image[label == 17] r_hip = image[label == 53] if len(l_hip) > 0 and len(r_hip) > 0: # Both drawn from same distribution — means should be close mean_diff = abs(l_hip.mean() - r_hip.mean()) pooled_std = max(l_hip.std(), r_hip.std(), 1e-6) cv = mean_diff / pooled_std assert cv < 0.5 # within-class similarity def test_different_classes_differ(self, tmp_path): from nobrainer.augmentation.synthseg import SynthSegGenerator path = _make_label_map(tmp_path) gen = SynthSegGenerator( [path], n_samples_per_map=1, elastic_std=0, rotation_range=0, flipping=False, randomize_resolution=False, noise_std=0, bias_field_std=0, ) sample = gen[0] image = sample["image"][0].numpy() label = sample["label"][0].numpy() wm = image[label == 2] csf = image[label == 4] # WM and CSF should have different distributions (different classes) if len(wm) > 10 and len(csf) > 10: # Not guaranteed to differ every time but very likely assert wm.mean() != pytest.approx(csf.mean(), abs=1.0) def test_two_runs_produce_different_intensities(self, tmp_path): from nobrainer.augmentation.synthseg import SynthSegGenerator path = _make_label_map(tmp_path) gen = SynthSegGenerator( [path], n_samples_per_map=2, elastic_std=0, rotation_range=0, flipping=False, randomize_resolution=False, ) s1 = gen[0]["image"] s2 = gen[1]["image"] assert not torch.allclose(s1, s2) class TestSpatialAugmentation: def test_elastic_changes_geometry(self, tmp_path): from nobrainer.augmentation.synthseg import SynthSegGenerator path = _make_label_map(tmp_path) gen = SynthSegGenerator( [path], n_samples_per_map=1, elastic_std=4.0, rotation_range=0, flipping=False, randomize_resolution=False, noise_std=0, bias_field_std=0, ) sample = gen[0] label = sample["label"][0].numpy() # Load original label for comparison orig = np.asarray(nib.load(path).dataobj, 
dtype=np.int32) # Elastic deformation should change some voxel positions changed = (label != orig).sum() total = orig.size assert changed / total > 0.01 # at least 1% changed def test_label_nearest_neighbor(self, tmp_path): """Labels should remain integer-valued after spatial augmentation.""" from nobrainer.augmentation.synthseg import SynthSegGenerator from nobrainer.data.tissue_classes import FREESURFER_LR_PAIRS path = _make_label_map(tmp_path) gen = SynthSegGenerator( [path], n_samples_per_map=1, elastic_std=4.0, rotation_range=15.0, randomize_resolution=False, ) sample = gen[0] label = sample["label"][0].numpy() # All values should be valid integers (no interpolation artifacts) # Include L/R swapped labels since flipping may have occurred orig_labels = set(np.asarray(nib.load(path).dataobj, dtype=np.int32).flat) valid_labels = set(orig_labels) for left, right in FREESURFER_LR_PAIRS: if left in orig_labels: valid_labels.add(right) if right in orig_labels: valid_labels.add(left) actual_labels = set(label.flat) assert actual_labels.issubset(valid_labels) def test_flipping_swaps_lr(self, tmp_path): """Flipping should swap L/R FreeSurfer codes.""" from nobrainer.augmentation.synthseg import SynthSegGenerator path = _make_label_map(tmp_path) gen = SynthSegGenerator( [path], n_samples_per_map=30, elastic_std=0, rotation_range=0, flipping=True, randomize_resolution=False, noise_std=0, bias_field_std=0, ) # After L/R flip, label 17 (L hippocampus) should become 53 and vice versa # Check that at least one sample has the swap in the label set found_swap = False orig = np.asarray(nib.load(path).dataobj, dtype=np.int32) for i in range(30): label = gen[i]["label"][0].numpy() # A flip swaps L/R labels AND mirrors spatially. 
# If the spatial distribution of label 17 differs from original, flip happened orig_17_count = (orig == 17).sum() new_17_count = (label == 17).sum() if orig_17_count > 0 and new_17_count != orig_17_count: found_swap = True break assert found_swap class TestResolutionRandomization: def test_blurs_image(self, tmp_path): from nobrainer.augmentation.synthseg import SynthSegGenerator path = _make_label_map(tmp_path) gen_sharp = SynthSegGenerator( [path], n_samples_per_map=1, elastic_std=0, rotation_range=0, flipping=False, randomize_resolution=False, noise_std=0, bias_field_std=0, ) gen_blur = SynthSegGenerator( [path], n_samples_per_map=1, elastic_std=0, rotation_range=0, flipping=False, randomize_resolution=True, resolution_range=(2.0, 3.0), # force heavy blur noise_std=0, bias_field_std=0, ) # Use same seed for intensity but different resolution np.random.seed(42) sharp = gen_sharp[0]["image"][0].numpy() np.random.seed(42) blurred = gen_blur[0]["image"][0].numpy() # Blurred should have less high-frequency energy sharp_grad = np.abs(np.diff(sharp, axis=0)).mean() blur_grad = np.abs(np.diff(blurred, axis=0)).mean() assert blur_grad < sharp_grad class TestOutputFormat: def test_returns_dict_with_correct_keys(self, tmp_path): from nobrainer.augmentation.synthseg import SynthSegGenerator path = _make_label_map(tmp_path) gen = SynthSegGenerator([path], n_samples_per_map=1) sample = gen[0] assert "image" in sample assert "label" in sample assert sample["image"].shape[0] == 1 # channel dim assert sample["label"].shape[0] == 1 assert sample["image"].dtype == torch.float32 assert sample["label"].dtype == torch.int64 def test_correct_length(self, tmp_path): from nobrainer.augmentation.synthseg import SynthSegGenerator path = _make_label_map(tmp_path) gen = SynthSegGenerator([path, path], n_samples_per_map=5) assert len(gen) == 10 class TestMixedDataset: def test_mix_ratio(self, tmp_path): """Mixed dataset produces approximately correct ratio.""" from 
nobrainer.augmentation.synthseg import SynthSegGenerator from nobrainer.processing.dataset import MixedDataset path = _make_label_map(tmp_path) gen = SynthSegGenerator( [path], n_samples_per_map=50, elastic_std=0, rotation_range=0, flipping=False, randomize_resolution=False, noise_std=0, bias_field_std=0, ) # Create a simple "real" dataset real = gen # reuse generator as real for simplicity mixed = MixedDataset(real, gen, ratio=0.5) assert len(mixed) == 50 # Just verify it returns dicts without error sample = mixed[0] assert "image" in sample or isinstance(sample, dict) def test_dataset_mix_method(self, tmp_path): """Dataset.mix() returns a Dataset with _mixed_dataset set.""" from nobrainer.augmentation.synthseg import SynthSegGenerator from nobrainer.processing.dataset import Dataset path = _make_label_map(tmp_path) gen = SynthSegGenerator( [path], n_samples_per_map=5, elastic_std=0, rotation_range=0, flipping=False, randomize_resolution=False, ) pairs = [(str(tmp_path / "label.nii.gz"), str(tmp_path / "label.nii.gz"))] ds = Dataset.from_files(pairs, block_shape=(16, 16, 16), n_classes=2) mixed = ds.mix(gen, ratio=0.3) assert hasattr(mixed, "_mixed_dataset") ================================================ FILE: nobrainer/tests/unit/test_training.py ================================================ """Unit tests for nobrainer.training.fit().""" from __future__ import annotations import torch import torch.nn as nn from torch.utils.data import DataLoader, TensorDataset from nobrainer.training import fit def _make_loader(n=8, spatial=8, n_classes=2, batch_size=2): """Synthetic DataLoader for training tests.""" x = torch.randn(n, 1, spatial, spatial, spatial) y = torch.randint(0, n_classes, (n, spatial, spatial, spatial)) ds = TensorDataset(x, y) return DataLoader(ds, batch_size=batch_size) def _make_model(n_classes=2): """Tiny conv model for testing.""" return nn.Sequential( nn.Conv3d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv3d(8, n_classes, 1), ) class TestFit: def 
test_returns_correct_keys(self): model = _make_model() loader = _make_loader() result = fit( model, loader, nn.CrossEntropyLoss(), torch.optim.Adam(model.parameters()), max_epochs=2, ) assert "history" in result assert "checkpoint_path" in result assert len(result["history"]) == 2 def test_loss_decreases(self): torch.manual_seed(42) model = _make_model() loader = _make_loader() result = fit( model, loader, nn.CrossEntropyLoss(), torch.optim.Adam(model.parameters(), lr=1e-2), max_epochs=10, ) losses = [h["loss"] for h in result["history"]] assert ( losses[-1] < losses[0] ), f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}" def test_checkpoint_created(self, tmp_path): model = _make_model() loader = _make_loader() result = fit( model, loader, nn.CrossEntropyLoss(), torch.optim.Adam(model.parameters()), max_epochs=2, checkpoint_dir=tmp_path, ) assert result["checkpoint_path"] is not None assert (tmp_path / "best_model.pth").exists() assert (tmp_path / "croissant.json").exists() def test_checkpoint_croissant_content(self, tmp_path): """Checkpoint croissant.json contains provenance metadata.""" import json model = _make_model() loader = _make_loader() fit( model, loader, nn.CrossEntropyLoss(), torch.optim.Adam(model.parameters()), max_epochs=2, checkpoint_dir=tmp_path, ) data = json.loads((tmp_path / "croissant.json").read_text()) prov = data["nobrainer:provenance"] assert prov["epochs_trained"] > 0 assert prov["model_architecture"] == "Sequential" assert prov["loss_function"] == "CrossEntropyLoss" assert "optimizer" in prov def test_epochs_completed(self): model = _make_model() loader = _make_loader() result = fit( model, loader, nn.CrossEntropyLoss(), torch.optim.Adam(model.parameters()), max_epochs=3, ) assert len(result["history"]) == 3 def test_dict_batch_format(self): """fit() works with dict-style batches (from MONAI DataLoader).""" x = torch.randn(4, 1, 8, 8, 8) y = torch.randint(0, 2, (4, 8, 8, 8)) class 
DictDataset(torch.utils.data.Dataset): def __len__(self): return 4 def __getitem__(self, idx): return {"image": x[idx], "label": y[idx]} loader = DataLoader(DictDataset(), batch_size=2) model = _make_model() result = fit( model, loader, nn.CrossEntropyLoss(), torch.optim.Adam(model.parameters()), max_epochs=1, ) assert len(result["history"]) == 1 ================================================ FILE: nobrainer/tests/unit/test_training_convergence.py ================================================ """CPU training-convergence smoke tests (US1 acceptance scenario 3). Verifies that each core segmentation model's training loss at epoch 5 is lower than at epoch 1 when overfitting a fixed batch on CPU. Scope: tests nobrainer.models + nobrainer.losses integration; does NOT require GPU or real data. """ from __future__ import annotations import torch from nobrainer.losses import dice from nobrainer.models.highresnet import highresnet from nobrainer.models.meshnet import meshnet from nobrainer.models.segmentation import attention_unet, unet, vnet # Shared synthetic batch: batch_size=2 (satisfies BatchNorm), 32^3 spatial. # Fixed all-ones label is easy to overfit, keeping the test deterministic. 
_SPATIAL = 32 _N_EPOCHS = 5 _LR = 1e-2 def _run_epochs(model: torch.nn.Module, seed: int = 42) -> list[float]: """Train *model* for _N_EPOCHS and return per-epoch loss values.""" torch.manual_seed(seed) model.train() x = torch.randn(2, 1, _SPATIAL, _SPATIAL, _SPATIAL) y = torch.ones(2, 1, _SPATIAL, _SPATIAL, _SPATIAL) loss_fn = dice() opt = torch.optim.Adam(model.parameters(), lr=_LR) losses = [] for _ in range(_N_EPOCHS): opt.zero_grad() pred = model(x) loss = loss_fn(pred, y) loss.backward() opt.step() losses.append(loss.item()) return losses class TestTrainingConvergence: """US1 scenario 3: loss at epoch 5 < loss at epoch 1 for all core models.""" def test_unet_loss_decreases(self): losses = _run_epochs(unet(n_classes=1)) assert ( losses[-1] < losses[0] ), f"UNet loss did not decrease: epoch1={losses[0]:.4f}, epoch5={losses[-1]:.4f}" def test_vnet_loss_decreases(self): losses = _run_epochs(vnet(n_classes=1)) assert ( losses[-1] < losses[0] ), f"VNet loss did not decrease: epoch1={losses[0]:.4f}, epoch5={losses[-1]:.4f}" def test_attention_unet_loss_decreases(self): losses = _run_epochs(attention_unet(n_classes=1)) assert ( losses[-1] < losses[0] ), f"AttentionUNet loss did not decrease: epoch1={losses[0]:.4f}, epoch5={losses[-1]:.4f}" def test_meshnet_loss_decreases(self): losses = _run_epochs(meshnet(n_classes=1)) assert ( losses[-1] < losses[0] ), f"MeshNet loss did not decrease: epoch1={losses[0]:.4f}, epoch5={losses[-1]:.4f}" def test_highresnet_loss_decreases(self): losses = _run_epochs(highresnet(n_classes=1)) assert ( losses[-1] < losses[0] ), f"HighResNet loss did not decrease: epoch1={losses[0]:.4f}, epoch5={losses[-1]:.4f}" ================================================ FILE: nobrainer/tests/unit/test_transform_pipeline.py ================================================ """Unit tests for TrainableCompose and Augmentation tagging.""" from __future__ import annotations from nobrainer.augmentation.transforms import Augmentation, TrainableCompose def 
_identity(data): """Preprocessing transform that passes data through.""" data["preprocess_count"] = data.get("preprocess_count", 0) + 1 return data def _augment(data): """Augmentation transform that modifies data.""" data["augment_count"] = data.get("augment_count", 0) + 1 return data class TestAugmentation: def test_wraps_transform(self): aug = Augmentation(_augment) assert aug.is_augmentation is True result = aug({"x": 1}) assert result["augment_count"] == 1 def test_repr(self): aug = Augmentation(_augment) assert "Augmentation" in repr(aug) class TestTrainableCompose: def test_train_mode_runs_all(self): pipeline = TrainableCompose([_identity, Augmentation(_augment)]) result = pipeline({"x": 1}, mode="train") assert result["preprocess_count"] == 1 assert result["augment_count"] == 1 def test_predict_mode_skips_augmentation(self): pipeline = TrainableCompose([_identity, Augmentation(_augment)]) result = pipeline({"x": 1}, mode="predict") assert result["preprocess_count"] == 1 assert "augment_count" not in result def test_default_mode_is_train(self): pipeline = TrainableCompose([_identity, Augmentation(_augment)]) result = pipeline({"x": 1}) assert result["augment_count"] == 1 def test_mode_setter(self): pipeline = TrainableCompose([_identity, Augmentation(_augment)]) pipeline.mode = "predict" result = pipeline({"x": 1}) assert "augment_count" not in result def test_multiple_augmentations_skipped(self): pipeline = TrainableCompose( [ _identity, Augmentation(_augment), _identity, Augmentation(_augment), ] ) result = pipeline({"x": 1}, mode="predict") assert result["preprocess_count"] == 2 assert "augment_count" not in result def test_train_mode_runs_multiple_augmentations(self): pipeline = TrainableCompose( [ _identity, Augmentation(_augment), _identity, Augmentation(_augment), ] ) result = pipeline({"x": 1}, mode="train") assert result["preprocess_count"] == 2 assert result["augment_count"] == 2 def test_empty_pipeline(self): pipeline = TrainableCompose([]) result 
= pipeline({"x": 1}, mode="train") assert result == {"x": 1} class TestAugmentationProfiles: def test_none_returns_empty(self): from nobrainer.augmentation.profiles import get_augmentation_profile transforms = get_augmentation_profile("none") assert transforms == [] def test_standard_returns_augmentations(self): from nobrainer.augmentation.profiles import get_augmentation_profile transforms = get_augmentation_profile("standard") assert len(transforms) > 0 assert all(getattr(t, "is_augmentation", False) for t in transforms) def test_all_profiles_valid(self): from nobrainer.augmentation.profiles import get_augmentation_profile for name in ("none", "light", "standard", "heavy"): transforms = get_augmentation_profile(name) assert isinstance(transforms, list) def test_unknown_profile_raises(self): import pytest from nobrainer.augmentation.profiles import get_augmentation_profile with pytest.raises(ValueError, match="Unknown augmentation profile"): get_augmentation_profile("extreme") ================================================ FILE: nobrainer/tests/unit/test_vwn_layers.py ================================================ """Unit tests for VWN layers and KWYKMeshNet.""" from __future__ import annotations import torch from nobrainer.models.bayesian.vwn_layers import ConcreteDropout3d, FFGConv3d class TestFFGConv3d: def test_output_shape(self): layer = FFGConv3d(1, 4, kernel_size=3, padding=1) x = torch.randn(2, 1, 8, 8, 8) out = layer(x, mc=True) assert out.shape == (2, 4, 8, 8, 8) def test_deterministic_mode(self): layer = FFGConv3d(1, 4, kernel_size=3, padding=1) x = torch.randn(2, 1, 8, 8, 8) layer.eval() out1 = layer(x, mc=False) out2 = layer(x, mc=False) assert torch.allclose(out1, out2) def test_stochastic_mode_varies(self): layer = FFGConv3d(1, 4, kernel_size=3, padding=1) x = torch.randn(2, 1, 8, 8, 8) out1 = layer(x, mc=True) out2 = layer(x, mc=True) # Outputs should differ due to stochastic sampling assert not torch.allclose(out1, out2) def 
test_kl_populated_after_mc(self): layer = FFGConv3d(1, 4, kernel_size=3, padding=1) x = torch.randn(2, 1, 8, 8, 8) layer(x, mc=True) assert layer.kl.item() > 0 def test_kernel_m_shape(self): layer = FFGConv3d(1, 4, kernel_size=3, padding=1) assert layer.kernel_m.shape == (4, 1, 3, 3, 3) def test_no_bias(self): layer = FFGConv3d(1, 4, kernel_size=3, padding=1, bias=False) assert layer.bias_m is None x = torch.randn(2, 1, 8, 8, 8) out = layer(x, mc=True) assert out.shape == (2, 4, 8, 8, 8) def test_sigma_positive(self): layer = FFGConv3d(1, 4, kernel_size=3, padding=1) assert (layer.weight_sigma >= 0).all() class TestConcreteDropout3d: def test_output_shape(self): cd = ConcreteDropout3d(4) x = torch.randn(2, 4, 8, 8, 8) out = cd(x, mc=True) assert out.shape == x.shape def test_deterministic_scales(self): cd = ConcreteDropout3d(4) x = torch.ones(2, 4, 8, 8, 8) out = cd(x, mc=False) # In deterministic mode, output = x * p p = cd.p.view(1, -1, 1, 1, 1) expected = x * p assert torch.allclose(out, expected) def test_p_in_range(self): cd = ConcreteDropout3d(4) assert (cd.p >= 0.05).all() assert (cd.p <= 0.95).all() def test_regularization_positive(self): cd = ConcreteDropout3d(4) reg = cd.regularization() assert reg.item() > 0 class TestKWYKMeshNet: def test_bernoulli_variant(self): from nobrainer.models.bayesian.kwyk_meshnet import KWYKMeshNet model = KWYKMeshNet( n_classes=2, filters=8, receptive_field=37, dropout_type="bernoulli", ) x = torch.randn(1, 1, 16, 16, 16) out = model(x, mc=True) assert out.shape == (1, 2, 16, 16, 16) def test_concrete_variant(self): from nobrainer.models.bayesian.kwyk_meshnet import KWYKMeshNet model = KWYKMeshNet( n_classes=2, filters=8, receptive_field=37, dropout_type="concrete", ) x = torch.randn(1, 1, 16, 16, 16) out = model(x, mc=True) assert out.shape == (1, 2, 16, 16, 16) def test_kl_divergence(self): from nobrainer.models.bayesian.kwyk_meshnet import KWYKMeshNet model = KWYKMeshNet(n_classes=2, filters=8, receptive_field=37) x = 
torch.randn(1, 1, 16, 16, 16) model(x, mc=True) kl = model.kl_divergence() assert torch.isfinite(kl) def test_concrete_regularization(self): from nobrainer.models.bayesian.kwyk_meshnet import KWYKMeshNet model = KWYKMeshNet( n_classes=2, filters=8, receptive_field=37, dropout_type="concrete", ) reg = model.concrete_regularization() assert reg.item() > 0 def test_deterministic_forward(self): from nobrainer.models.bayesian.kwyk_meshnet import KWYKMeshNet model = KWYKMeshNet(n_classes=2, filters=8, receptive_field=37) x = torch.randn(1, 1, 16, 16, 16) model.eval() out1 = model(x, mc=False) out2 = model(x, mc=False) assert torch.allclose(out1, out2) def test_factory_function(self): from nobrainer.models import get model = get("kwyk_meshnet")(n_classes=2, filters=8, receptive_field=37) x = torch.randn(1, 1, 16, 16, 16) out = model(x) assert out.shape == (1, 2, 16, 16, 16) ================================================ FILE: nobrainer/tests/unit/test_zarr_dataset.py ================================================ """Unit tests for ZarrDataset and get_dataset() Zarr routing.""" from __future__ import annotations import nibabel as nib import numpy as np import pytest import torch zarr = pytest.importorskip("zarr", reason="zarr not installed") from nobrainer.dataset import ZarrDataset, _is_zarr_path # noqa: E402 from nobrainer.io import nifti_to_zarr # noqa: E402 def _make_zarr_pair(tmp_path, shape=(32, 32, 32)): """Create a synthetic NIfTI → Zarr pair (image + label).""" img_data = np.random.rand(*shape).astype(np.float32) lbl_data = (np.random.rand(*shape) > 0.5).astype(np.float32) img_nii = tmp_path / "img.nii.gz" lbl_nii = tmp_path / "lbl.nii.gz" nib.save(nib.Nifti1Image(img_data, np.eye(4)), str(img_nii)) nib.save(nib.Nifti1Image(lbl_data, np.eye(4)), str(lbl_nii)) img_zarr = nifti_to_zarr(img_nii, tmp_path / "img.zarr") lbl_zarr = nifti_to_zarr(lbl_nii, tmp_path / "lbl.zarr") return img_zarr, lbl_zarr, img_data, lbl_data class TestIsZarrPath: def 
test_zarr_extension(self): assert _is_zarr_path("data/brain.zarr") assert _is_zarr_path("data/brain.zarr/") def test_non_zarr(self): assert not _is_zarr_path("data/brain.nii.gz") assert not _is_zarr_path("data/brain.h5") class TestZarrDataset: def test_returns_dict_with_image(self, tmp_path): img_zarr, _, _, _ = _make_zarr_pair(tmp_path) ds = ZarrDataset([{"image": str(img_zarr)}]) item = ds[0] assert "image" in item assert isinstance(item["image"], torch.Tensor) def test_image_shape_has_channel(self, tmp_path): img_zarr, _, img_data, _ = _make_zarr_pair(tmp_path) ds = ZarrDataset([{"image": str(img_zarr)}]) item = ds[0] # Should have channel dim: (1, D, H, W) assert item["image"].shape == (1, *img_data.shape) def test_returns_label_when_provided(self, tmp_path): img_zarr, lbl_zarr, _, _ = _make_zarr_pair(tmp_path) ds = ZarrDataset([{"image": str(img_zarr), "label": str(lbl_zarr)}]) item = ds[0] assert "label" in item assert isinstance(item["label"], torch.Tensor) def test_batch_from_dataloader(self, tmp_path): img_zarr, lbl_zarr, _, _ = _make_zarr_pair(tmp_path) data = [{"image": str(img_zarr), "label": str(lbl_zarr)}] ds = ZarrDataset(data) loader = torch.utils.data.DataLoader(ds, batch_size=1) batch = next(iter(loader)) assert batch["image"].shape[0] == 1 # batch dim assert batch["image"].ndim == 5 # (B, C, D, H, W) def test_multi_resolution_level(self, tmp_path): """Loading at level 1 gives downsampled shape.""" img_data = np.random.rand(64, 64, 64).astype(np.float32) nii_path = tmp_path / "big.nii.gz" nib.save(nib.Nifti1Image(img_data, np.eye(4)), str(nii_path)) zarr_path = nifti_to_zarr(nii_path, tmp_path / "big.zarr", levels=2) ds = ZarrDataset([{"image": str(zarr_path)}], zarr_level=1) item = ds[0] # Level 1 is 2x downsampled: (1, 32, 32, 32) assert item["image"].shape == (1, 32, 32, 32) ================================================ FILE: nobrainer/tests/unit/test_zarr_store.py ================================================ """Unit tests for 
nobrainer.datasets.zarr_store.""" from __future__ import annotations import json import nibabel as nib import numpy as np import pytest def _make_nifti_pair(tmp_path, idx, shape=(32, 32, 32)): """Create a NIfTI image + label pair.""" img_data = np.random.randn(*shape).astype(np.float32) lbl_data = np.random.randint(0, 5, shape, dtype=np.int32) affine = np.eye(4) img_path = tmp_path / f"sub-{idx:02d}_image.nii.gz" lbl_path = tmp_path / f"sub-{idx:02d}_label.nii.gz" nib.save(nib.Nifti1Image(img_data, affine), str(img_path)) nib.save(nib.Nifti1Image(lbl_data, affine), str(lbl_path)) return str(img_path), str(lbl_path) class TestCreateZarrStore: def test_creates_store(self, tmp_path): from nobrainer.datasets.zarr_store import create_zarr_store pairs = [_make_nifti_pair(tmp_path, i) for i in range(3)] store_path = create_zarr_store( pairs, tmp_path / "test.zarr", conform=False, ) assert store_path.exists() def test_stacked_4d_layout(self, tmp_path): import zarr from nobrainer.datasets.zarr_store import create_zarr_store pairs = [_make_nifti_pair(tmp_path, i) for i in range(3)] store_path = create_zarr_store( pairs, tmp_path / "test.zarr", conform=False, ) store = zarr.open_group(str(store_path), mode="r") assert store["images"].shape == (3, 32, 32, 32) assert store["labels"].shape == (3, 32, 32, 32) assert store["images"].dtype == np.float32 assert store["labels"].dtype == np.int32 def test_metadata_stored(self, tmp_path): from nobrainer.datasets.zarr_store import create_zarr_store, store_info pairs = [_make_nifti_pair(tmp_path, i) for i in range(3)] store_path = create_zarr_store( pairs, tmp_path / "test.zarr", subject_ids=["sub-00", "sub-01", "sub-02"], conform=False, ) info = store_info(store_path) assert info["n_subjects"] == 3 assert info["subject_ids"] == ["sub-00", "sub-01", "sub-02"] assert info["volume_shape"] == [32, 32, 32] assert info["layout"] == "stacked" def test_round_trip_fidelity(self, tmp_path): import zarr from nobrainer.datasets.zarr_store import 
create_zarr_store pairs = [_make_nifti_pair(tmp_path, i) for i in range(2)] store_path = create_zarr_store(pairs, tmp_path / "test.zarr", conform=False) # Read back and compare original = np.asarray(nib.load(pairs[0][0]).dataobj, dtype=np.float32) store = zarr.open_group(str(store_path), mode="r") stored = np.array(store["images"][0]) assert np.allclose(original, stored, atol=1e-6) def test_partial_io(self, tmp_path): import zarr from nobrainer.datasets.zarr_store import create_zarr_store pairs = [_make_nifti_pair(tmp_path, i) for i in range(5)] store_path = create_zarr_store(pairs, tmp_path / "test.zarr", conform=False) store = zarr.open_group(str(store_path), mode="r") # Read a subregion from subject 2 patch = np.array(store["images"][2, 8:24, 8:24, 8:24]) assert patch.shape == (16, 16, 16) def test_auto_conform(self, tmp_path): import zarr from nobrainer.datasets.zarr_store import create_zarr_store # Create volumes with different shapes — conform should make them uniform # Use shapes where median is 32-divisible for sharding compat pairs = [ _make_nifti_pair(tmp_path, 0, shape=(32, 32, 32)), _make_nifti_pair(tmp_path, 1, shape=(32, 32, 32)), _make_nifti_pair(tmp_path, 2, shape=(64, 64, 64)), ] store_path = create_zarr_store( pairs, tmp_path / "test.zarr", conform=True, ) store = zarr.open_group(str(store_path), mode="r") # All subjects should have same shape assert store["images"].shape[1:] == store["images"].shape[1:] info = dict(store.attrs) assert info["conformed"] is True def test_non_uniform_without_conform_raises(self, tmp_path): from nobrainer.datasets.zarr_store import create_zarr_store pairs = [ _make_nifti_pair(tmp_path, 0, shape=(32, 32, 32)), _make_nifti_pair(tmp_path, 1, shape=(64, 64, 64)), ] with pytest.raises(ValueError, match="Non-uniform shapes"): create_zarr_store(pairs, tmp_path / "test.zarr", conform=False) class TestPartition: def test_create_partition(self, tmp_path): from nobrainer.datasets.zarr_store import create_partition, 
create_zarr_store pairs = [_make_nifti_pair(tmp_path, i) for i in range(10)] store_path = create_zarr_store(pairs, tmp_path / "test.zarr", conform=False) part_path = create_partition(store_path, ratios=(80, 10, 10)) assert part_path.exists() with open(part_path) as f: data = json.load(f) assert len(data["partitions"]["train"]) == 8 assert len(data["partitions"]["val"]) == 1 assert len(data["partitions"]["test"]) == 1 def test_load_partition(self, tmp_path): from nobrainer.datasets.zarr_store import ( create_partition, create_zarr_store, load_partition, ) pairs = [_make_nifti_pair(tmp_path, i) for i in range(10)] store_path = create_zarr_store(pairs, tmp_path / "test.zarr", conform=False) part_path = create_partition(store_path) partitions = load_partition(part_path) assert "train" in partitions assert "val" in partitions assert "test" in partitions all_ids = partitions["train"] + partitions["val"] + partitions["test"] assert len(set(all_ids)) == 10 # no duplicates def test_different_seeds_produce_different_splits(self, tmp_path): from nobrainer.datasets.zarr_store import ( create_partition, create_zarr_store, load_partition, ) pairs = [_make_nifti_pair(tmp_path, i) for i in range(10)] store_path = create_zarr_store(pairs, tmp_path / "test.zarr", conform=False) p1 = load_partition( create_partition(store_path, seed=1, output_path=tmp_path / "p1.json") ) p2 = load_partition( create_partition(store_path, seed=2, output_path=tmp_path / "p2.json") ) # Different seeds should produce different train sets (with high probability) assert p1["train"] != p2["train"] ================================================ FILE: nobrainer/training.py ================================================ """Training utilities with optional multi-GPU DDP support.""" from __future__ import annotations import logging from pathlib import Path from typing import Any import torch import torch.nn as nn from torch.utils.data import DataLoader logger = logging.getLogger(__name__) def get_device() -> 
torch.device: """Select the best available device: CUDA > MPS > CPU. .. note:: Also available as :func:`nobrainer.gpu.get_device`. """ from nobrainer.gpu import get_device as _get_device return _get_device() def _run_validation( model: nn.Module, val_loader: DataLoader, criterion: nn.Module, device: torch.device, ) -> dict[str, float]: """Run one validation pass. Returns dict with val_loss, val_acc (overall), and val_bal_acc (balanced accuracy — mean of per-class recall). """ model.eval() total_loss = 0.0 n_correct = 0 n_total = 0 n_batches = 0 # Per-class correct/total for balanced accuracy class_correct: dict[int, int] = {} class_total: dict[int, int] = {} with torch.no_grad(): for batch in val_loader: if isinstance(batch, dict): images = batch["image"].to(device) labels = batch["label"].to(device) elif isinstance(batch, (list, tuple)): images = batch[0].to(device) labels = batch[1].to(device) else: continue if labels.ndim == images.ndim and labels.shape[1] == 1: labels = labels.squeeze(1) if labels.dtype in (torch.float32, torch.float64): labels = labels.long() pred = model(images) # Handle model parallel: pred may be on different device if pred.device != labels.device: labels = labels.to(pred.device) total_loss += criterion(pred, labels).item() pred_labels = pred.argmax(1) correct_mask = pred_labels == labels n_correct += correct_mask.sum().item() n_total += labels.numel() n_batches += 1 # Accumulate per-class stats for c in labels.unique().tolist(): mask = labels == c cc = correct_mask[mask].sum().item() ct = mask.sum().item() class_correct[c] = class_correct.get(c, 0) + cc class_total[c] = class_total.get(c, 0) + ct val_loss = total_loss / max(n_batches, 1) val_acc = n_correct / max(n_total, 1) # Balanced accuracy: mean recall per class per_class_recall = [] for c in sorted(class_total.keys()): if class_total[c] > 0: per_class_recall.append(class_correct[c] / class_total[c]) val_bal_acc = sum(per_class_recall) / max(len(per_class_recall), 1) return 
{"val_loss": val_loss, "val_acc": val_acc, "val_bal_acc": val_bal_acc} def _apply_gradient_checkpointing(model: nn.Module) -> None: """Enable gradient checkpointing on sequential layers to save memory. Wraps each layer's forward in ``torch.utils.checkpoint.checkpoint`` so intermediate activations are recomputed during backward instead of stored. Roughly halves activation memory at ~30% compute cost. """ from torch.utils.checkpoint import checkpoint for name, module in model.named_children(): orig_forward = module.forward def _make_ckpt_forward(fwd): def _ckpt_forward(*args, **kwargs): # checkpoint requires at least one tensor with requires_grad def run(*a): return fwd(*a, **kwargs) tensors = [a for a in args if isinstance(a, torch.Tensor)] if tensors and any(t.requires_grad for t in tensors): return checkpoint(run, *args, use_reentrant=False) return fwd(*args, **kwargs) return _ckpt_forward module.forward = _make_ckpt_forward(orig_forward) logger.info( "Gradient checkpointing enabled on %d modules", len(list(model.children())) ) def _apply_model_parallel(model: nn.Module, gpus: int) -> nn.Module: """Distribute model layers across multiple GPUs (pipeline parallelism). Splits the model's children into ``gpus`` roughly equal groups and places each group on a different GPU. Inserts device-transfer hooks between groups so tensors move between GPUs automatically. Parameters ---------- model : nn.Module Model with sequential children (e.g., KWYKMeshNet). gpus : int Number of GPUs to distribute across. Returns ------- nn.Module The model with layers placed on different GPUs and transfer hooks. 
""" children = list(model.named_children()) if not children: logger.warning("Model has no children — placing on GPU 0") return model.to("cuda:0") # Split children into roughly equal groups n = len(children) group_size = max(1, (n + gpus - 1) // gpus) groups: list[list[tuple[str, nn.Module]]] = [] for i in range(0, n, group_size): groups.append(children[i : i + group_size]) # Place each group on its GPU device_map: dict[str, int] = {} for gpu_idx, group in enumerate(groups): device = torch.device(f"cuda:{gpu_idx}") for name, module in group: module.to(device) device_map[name] = gpu_idx logger.info( "Model parallel: %d layers across %d GPUs: %s", n, min(gpus, len(groups)), {k: f"cuda:{v}" for k, v in device_map.items()}, ) # Wrap forward to move tensors between devices orig_forward = model.forward def _mp_forward(*args, **kwargs): # Move input to first device first_mod = groups[0][0][1] dev_idx = first_mod.weight.device.index if hasattr(first_mod, "weight") else 0 first_device = torch.device(f"cuda:{dev_idx}") new_args = tuple( a.to(first_device) if isinstance(a, torch.Tensor) else a for a in args ) return orig_forward(*new_args, **kwargs) model.forward = _mp_forward # Add hooks to move activations between GPUs at group boundaries for gpu_idx, group in enumerate(groups): if gpu_idx == 0: continue target_device = torch.device(f"cuda:{gpu_idx}") first_module = group[0][1] def _make_hook(dev): def _hook(module, inputs): return tuple( x.to(dev) if isinstance(x, torch.Tensor) else x for x in inputs ) return _hook first_module.register_forward_pre_hook(_make_hook(target_device)) return model def fit( model: nn.Module, loader: DataLoader, criterion: nn.Module, optimizer: torch.optim.Optimizer, max_epochs: int = 10, gpus: int = 1, checkpoint_dir: str | Path | None = None, callbacks: list[Any] | None = None, val_loader: DataLoader | None = None, checkpoint_freq: int = 0, gradient_checkpointing: bool = False, model_parallel: bool = False, resume_from: str | Path | None = None, 
) -> dict: """Train a model with optional multi-GPU DDP or model parallelism. Parameters ---------- model : nn.Module PyTorch model to train. loader : DataLoader Training data loader. criterion : nn.Module Loss function. optimizer : Optimizer PyTorch optimizer. max_epochs : int Number of training epochs. gpus : int Number of GPUs to use (1 = single GPU/CPU, >1 = DDP or model parallel). checkpoint_dir : path or None Directory for saving checkpoints. None disables checkpointing. callbacks : list or None Optional callback functions called after each epoch with signature ``callback(epoch, logs, model)`` where logs is a dict containing at minimum ``{"loss": float}``. val_loader : DataLoader or None Validation data loader. If provided, validation loss and accuracy are computed each epoch and included in the logs dict. checkpoint_freq : int Save a checkpoint every N epochs (in addition to best model). 0 = only save best model. Checkpoints are saved as ``epoch_NNN.pth`` in checkpoint_dir. gradient_checkpointing : bool If True, trade compute for memory by recomputing activations during backward. Roughly halves activation memory. model_parallel : bool If True and gpus > 1, distribute layers across GPUs (pipeline parallelism) instead of DDP. Useful when a single batch is too large for one GPU. 
Returns ------- dict ``{"history": [{"epoch": int, "loss": float, ...}, ...], "checkpoint_path": str | None}`` """ device = get_device() # Apply gradient checkpointing if requested if gradient_checkpointing: _apply_gradient_checkpointing(model) # Multi-GPU dispatch if gpus > 1 and torch.cuda.device_count() >= gpus: if model_parallel: # Pipeline parallelism: split layers across GPUs model = _apply_model_parallel(model, gpus) device = torch.device("cuda:0") # input goes to first GPU # Fall through to single-process training loop below else: # Data parallelism: DDP return _fit_ddp( model, loader, criterion, optimizer, max_epochs, gpus, checkpoint_dir, callbacks, val_loader, checkpoint_freq, ) if not model_parallel: model = model.to(device) best_loss = float("inf") ckpt_path = None history: list[dict[str, Any]] = [] # one entry per epoch start_epoch = 0 if checkpoint_dir is not None: checkpoint_dir = Path(checkpoint_dir) checkpoint_dir.mkdir(parents=True, exist_ok=True) # Resume from checkpoint (auto-detect or explicit path) resume_path = None if resume_from is not None: resume_path = Path(resume_from) elif checkpoint_dir is not None and (checkpoint_dir / "checkpoint.pt").exists(): resume_path = checkpoint_dir / "checkpoint.pt" if resume_path is not None and resume_path.exists(): from nobrainer.slurm import load_checkpoint as _load_ckpt ckpt_dir = ( resume_path.parent if resume_path.name == "checkpoint.pt" else checkpoint_dir ) start_epoch, prev_metrics = _load_ckpt(ckpt_dir, model, optimizer) history = prev_metrics.get("history", []) best_loss = min((h["loss"] for h in history), default=float("inf")) logger.info( "Resumed from epoch %d (%d history entries, best_loss=%.4f)", start_epoch, len(history), best_loss, ) for epoch in range(start_epoch, max_epochs): model.train() epoch_loss = 0.0 n_batches = 0 for batch in loader: if isinstance(batch, dict): images = batch["image"].to(device) labels = batch["label"].to(device) elif isinstance(batch, (list, tuple)): images = 
batch[0].to(device) labels = batch[1].to(device) else: raise TypeError(f"Unsupported batch type: {type(batch)}") if labels.ndim == images.ndim and labels.shape[1] == 1: labels = labels.squeeze(1) if labels.dtype in (torch.float32, torch.float64): labels = labels.long() optimizer.zero_grad() pred = model(images) if pred.device != labels.device: labels = labels.to(pred.device) loss = criterion(pred, labels) loss.backward() optimizer.step() epoch_loss += loss.item() n_batches += 1 avg_loss = epoch_loss / max(n_batches, 1) # Epoch metrics logs: dict[str, Any] = {"epoch": epoch + 1, "loss": avg_loss} if val_loader is not None: logs.update(_run_validation(model, val_loader, criterion, device)) model.train() history.append(logs) # Best model if avg_loss < best_loss: best_loss = avg_loss if checkpoint_dir is not None: ckpt_path = str(checkpoint_dir / "best_model.pth") torch.save(model.state_dict(), ckpt_path) from nobrainer.processing.croissant import write_checkpoint_croissant write_checkpoint_croissant( checkpoint_dir, model, optimizer, criterion, history ) # Resumable checkpoint (every epoch) if checkpoint_dir is not None: from nobrainer.slurm import save_checkpoint as _save_ckpt _save_ckpt( checkpoint_dir, model, optimizer, epoch + 1, {"history": history} ) # Named checkpoint (for post-hoc Dice eval) if ( checkpoint_dir is not None and checkpoint_freq > 0 and (epoch + 1) % checkpoint_freq == 0 ): epoch_ckpt = checkpoint_dir / f"epoch_{epoch + 1:03d}.pth" torch.save(model.state_dict(), epoch_ckpt) if callbacks: for cb in callbacks: cb(epoch, logs, model) logger.debug( "Epoch %d/%d: %s", epoch + 1, max_epochs, " ".join(f"{k}={v:.4f}" for k, v in logs.items() if isinstance(v, float)), ) return {"history": history, "checkpoint_path": ckpt_path} def _ddp_worker( rank: int, world_size: int, model: nn.Module, train_dataset, val_dataset, batch_size: int, num_workers: int, criterion: nn.Module, optimizer: torch.optim.Optimizer, max_epochs: int, checkpoint_dir: str | Path | 
None, checkpoint_freq: int, result_dict: dict, ) -> None: """Single DDP worker — module-level function for mp.spawn pickling.""" import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data.distributed import DistributedSampler dist.init_process_group("nccl", rank=rank, world_size=world_size) torch.cuda.set_device(rank) device = torch.device(f"cuda:{rank}") local_model = model.to(device) ddp_model = DDP(local_model, device_ids=[rank]) sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank) ddp_loader = DataLoader( train_dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=True, ) # Validation loader on rank 0 only (no DDP sampler needed) val_loader = None if val_dataset is not None and rank == 0: val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, ) best_loss = float("inf") ckpt_path = None history: list[dict[str, Any]] = [] if checkpoint_dir is not None and rank == 0: Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) # ExperimentTracker for live metrics (rank 0 only) tracker = None if rank == 0 and checkpoint_dir is not None: from nobrainer.experiment import ExperimentTracker tracker = ExperimentTracker(output_dir=checkpoint_dir) for epoch in range(max_epochs): sampler.set_epoch(epoch) ddp_model.train() epoch_loss = 0.0 n_batches = 0 for batch in ddp_loader: if isinstance(batch, dict): images = batch["image"].to(device) labels = batch["label"].to(device) elif isinstance(batch, (list, tuple)): images = batch[0].to(device) labels = batch[1].to(device) else: raise TypeError(f"Unsupported batch type: {type(batch)}") if labels.ndim == images.ndim and labels.shape[1] == 1: labels = labels.squeeze(1) if labels.dtype in (torch.float32, torch.float64): labels = labels.long() optimizer.zero_grad() pred = ddp_model(images) loss = criterion(pred, labels) loss.backward() optimizer.step() 
epoch_loss += loss.item() n_batches += 1 avg_loss = epoch_loss / max(n_batches, 1) logs: dict[str, Any] = {"epoch": epoch + 1, "loss": avg_loss} if rank == 0: if val_loader is not None: logs.update( _run_validation(ddp_model.module, val_loader, criterion, device) ) ddp_model.train() history.append(logs) if avg_loss < best_loss: best_loss = avg_loss if checkpoint_dir is not None: ckpt_path = str(Path(checkpoint_dir) / "best_model.pth") torch.save(ddp_model.module.state_dict(), ckpt_path) from nobrainer.processing.croissant import ( write_checkpoint_croissant, ) write_checkpoint_croissant( checkpoint_dir, ddp_model.module, optimizer, criterion, history, ) # Resumable checkpoint if checkpoint_dir is not None: from nobrainer.slurm import save_checkpoint as _save_ckpt _save_ckpt( checkpoint_dir, ddp_model.module, optimizer, epoch + 1, {"history": history}, ) # Named checkpoint for post-hoc Dice eval if ( checkpoint_dir is not None and checkpoint_freq > 0 and (epoch + 1) % checkpoint_freq == 0 ): torch.save( ddp_model.module.state_dict(), Path(checkpoint_dir) / f"epoch_{epoch + 1:03d}.pth", ) if tracker is not None: tracker.log(logs) logger.info( "Epoch %d/%d: %s", epoch + 1, max_epochs, " ".join( f"{k}={v:.4f}" for k, v in logs.items() if isinstance(v, float) ), ) if rank == 0: result_dict["history"] = history result_dict["checkpoint_path"] = ckpt_path if tracker is not None: tracker.finish() dist.destroy_process_group() def _fit_ddp( model: nn.Module, loader: DataLoader, criterion: nn.Module, optimizer: torch.optim.Optimizer, max_epochs: int, gpus: int, checkpoint_dir: str | Path | None, callbacks: list[Any] | None, val_loader: DataLoader | None = None, checkpoint_freq: int = 0, ) -> dict: """Multi-GPU training via DistributedDataParallel. Launches ``gpus`` processes via ``mp.spawn``. Validation runs on rank 0 inside the worker — no callbacks needed for val metrics. 
""" import os import torch.multiprocessing as mp os.environ.setdefault("MASTER_ADDR", "localhost") os.environ.setdefault("MASTER_PORT", "29500") results: dict = mp.Manager().dict() # Extract datasets (picklable) — not DataLoaders (may have closures) train_dataset = loader.dataset val_dataset = val_loader.dataset if val_loader is not None else None mp.spawn( _ddp_worker, args=( gpus, model, train_dataset, val_dataset, loader.batch_size, loader.num_workers, criterion, optimizer, max_epochs, checkpoint_dir, checkpoint_freq, results, ), nprocs=gpus, join=True, ) result = dict(results) if "history" in result: result["history"] = list(result["history"]) return result __all__ = ["fit"] ================================================ FILE: nobrainer/utils.py ================================================ """Utilities for Nobrainer.""" from collections import namedtuple import csv import hashlib import os import tempfile import urllib.request import numpy as np import psutil _cache_dir = os.path.join(tempfile.gettempdir(), "nobrainer-data") def _sha256(path: str) -> str: """Compute SHA-256 hex digest of a file.""" h = hashlib.sha256() with open(path, "rb") as f: for chunk in iter(lambda: f.read(1 << 16), b""): h.update(chunk) return h.hexdigest() def _download_if_needed(url: str, dest: str, expected_hash: str) -> None: """Download *url* to *dest* if the file is missing or hash mismatches.""" if os.path.isfile(dest): if _sha256(dest) == expected_hash: return urllib.request.urlretrieve(url, dest) actual = _sha256(dest) if actual != expected_hash: raise RuntimeError( f"Hash mismatch for {dest}: expected {expected_hash}, got {actual}" ) def get_data(cache_dir=_cache_dir): """Download sample features and labels. The features are T1-weighted MGZ files, and the labels are the corresponding aparc+aseg MGZ files, created with FreeSurfer. This will download 46 megabytes of data. These data can be found at https://datasets.datalad.org/workshops/nih-2017/ds000114/. 
Parameters ---------- cache_dir: str, directory where to save the data. By default, saves to a temporary directory. Returns ------- List of `(features, labels)`. """ os.makedirs(cache_dir, exist_ok=True) URLHashPair = namedtuple("URLHashPair", "sub x_hash y_hash") hashes = [ URLHashPair( sub="sub-01", x_hash="67d0053f021d1d137bc99715e4e3ebb763364c8ce04311b1032d4253fc149f52", y_hash="7a85b628653f24e2b71cbef6dda86ab24a1743c5f6dbd996bdde258414e780b5", ), URLHashPair( sub="sub-02", x_hash="c0fee669a34bf3b43c8e4aecc88204512ef4e83f2e414640a5abc076b435990c", y_hash="c92357c2571da72d15332b2b4838b94d442d4abd3dbddc4b54202d68f0e19380", ), URLHashPair( sub="sub-03", x_hash="e2bba954e37f5791260f0ec573456e3293bbd40dba139bb1af417eaaeabe63e6", y_hash="e9204f0d50f06a89dd1870911f7ef5e9808e222227799a5384dceeb941ee8f9d", ), URLHashPair( sub="sub-04", x_hash="deec5245a2a5948f7e1053ace8d8a31396b14a96d520c6a52305434e75abe1e8", y_hash="c50e33a3f87aca351414e729b7c25404af364dfe5dd1de5fe380a460cbe9f891", ), URLHashPair( sub="sub-05", x_hash="8a7fe84918f3f80b87903a1e8f7bd20792c0ebc7528fb98513be373258dfd6c0", y_hash="682f52633633551d6fda71ede65aa41e16c332ebf42b4df042bc312200b0337c", ), URLHashPair( sub="sub-06", x_hash="f9a0c40bcd62d7b7e88015867ab5d926009b097ac3235499a541ac9072dd90c8", y_hash="31c842969af9ac178361fa8c13f656a47d27d95357abaf3e7f3521671aa17929", ), URLHashPair( sub="sub-07", x_hash="9de3b7392f5383e7391c5fcd9266d6b7ab6b57bc7ab203cc9ad2a29a2d31a85b", y_hash="b2e48bbfc4185261785643fc8ab066be5f97215b5a9b029ade1ffb12d54d616e", ), URLHashPair( sub="sub-08", x_hash="361098fc69c280970bb0b0d7ea6aba80d383c12e3ccfe5899693bc35b68efbe4", y_hash="0c980ef851b1391f580d91fc87c10d6d30315527cc0749c1010f2b7d5819a009", ), URLHashPair( sub="sub-09", x_hash="1456b35112297df5caacb9d33cb047aa85a3a5b4db3b4b5f9a5c2e189a684e1a", y_hash="696f1e9fef512193b71580292e0edc5835f396d2c8d63909c13668ef7bed433b", ), URLHashPair( sub="sub-10", x_hash="97447f17402e0f9990cd0917f281704893b52a9b61a3241b23a112a0a143d26e", 
y_hash="97a7947ba1a28963714c9f5c82520d9ef803d005695a0b4109d5a73d7e8a537b", ), ] x_filename = "t1.mgz" y_filename = "aparc+aseg.mgz" url_template = ( "https://datasets.datalad.org/workshops/nih-2017/ds000114/derivatives/" "freesurfer/{sub}/mri/{fname}" ) output = [("features", "labels")] downloads_dir = os.path.join(cache_dir, "datasets") os.makedirs(downloads_dir, exist_ok=True) for h in hashes: x_origin = url_template.format(sub=h.sub, fname=x_filename) y_origin = url_template.format(sub=h.sub, fname=y_filename) x_fname = h.sub + "_" + x_origin.rsplit("/", 1)[-1] y_fname = h.sub + "_" + y_origin.rsplit("/", 1)[-1] x_out = os.path.join(downloads_dir, x_fname) y_out = os.path.join(downloads_dir, y_fname) _download_if_needed(x_origin, x_out, h.x_hash) _download_if_needed(y_origin, y_out, h.y_hash) output.append((x_out, y_out)) csvpath = os.path.join(cache_dir, "filepaths.csv") with open(csvpath, "w", newline="") as f: writer = csv.writer(f) writer.writerows(output) return csvpath class StreamingStats: """Object to calculate statistics on streaming data. Compatible with scalars and n-dimensional arrays. Examples -------- ```python >>> s = StreamingStats() >>> s.update(10).update(20) >>> s.mean() 15.0 ``` ```python >>> import numpy as np >>> a = np.array([[0, 2], [4, 8]]) >>> b = np.array([[2, 4], [8, 16]]) >>> s = StreamingStats() >>> s.update(a).update(b) >>> s.mean() array([[ 1., 3.], [ 6., 12.]]) ``` """ def __init__(self): self._n_samples = 0 self._current_mean = 0.0 self._M = 0.0 def update(self, value): """Update the statistics with the next value. Parameters ---------- value: scalar, array-like Returns ------- Modified instance. """ if self._n_samples == 0: self._current_mean = value else: prev_mean = self._current_mean curr_mean = prev_mean + (value - prev_mean) / (self._n_samples + 1) _M = self._M + (prev_mean - value) * (curr_mean - value) # Set the instance attributes after computation in case there are # errors during computation. 
self._current_mean = curr_mean self._M = _M self._n_samples += 1 return self def mean(self): """Return current mean of streaming data.""" return self._current_mean def var(self): """Return current variance of streaming data.""" return self._M / self._n_samples def std(self): """Return current standard deviation of streaming data.""" return self.var() ** 0.5 def entropy(self): """Return current entropy of streaming data.""" eps = 1e-07 mult = np.multiply(np.log(self.mean() + eps), self.mean()) return -mult # return -np.sum(mult, axis=axis) def get_num_parallel(): # Get number of processes allocated to the current process. # Note the difference from `os.cpu_count()`. try: num_parallel_calls = len(psutil.Process().cpu_affinity()) except AttributeError: num_parallel_calls = psutil.cpu_count() return num_parallel_calls ================================================ FILE: nobrainer/validation.py ================================================ #!/usr/bin/env python3 from pathlib import Path import nibabel as nib import numpy as np from .io import read_mapping, read_volume from .metrics import dice as dice_numpy from .prediction import predict as _predict from .volume import normalize_numpy, replace DT_X = "float32" def validate_from_filepath( filepath, predictor, block_shape, n_classes, mapping_y, return_variance=False, return_entropy=False, return_array_from_images=False, n_samples=1, normalizer=normalize_numpy, batch_size=4, ): """Computes dice for a prediction compared to a ground truth image. Args: filepath: tuple, tuple of paths to existing neuroimaging volume (index 0) and ground truth (index 1). predictor: TensorFlow Predictor object, predictor from previously trained model. n_classes: int, number of classifications the model is trained to output. mapping_y: path-like, path to csv mapping file per command line argument. block_shape: tuple of len 3, shape of blocks on which to predict. return_variance: Boolean. 
If set True, it returns the running population variance along with mean. Note, if the n_samples is smaller or equal to 1, the variance will not be returned; instead it will return None return_entropy: Boolean. If set True, it returns the running entropy. along with mean. return_array_from_images: Boolean. If set True and the given input is either image, filepath, or filepaths, it will return arrays of [mean, variance, entropy] instead of images of them. Also, if the input is array, it will simply return array, whether or not this flag is True or False. n_samples: The number of sampling. If set as 1, it will just return the single prediction value. normalizer: callable, function that accepts an ndarray and returns an ndarray. Called before separating volume into blocks. batch_size: int, number of sub-volumes per batch for prediction. dtype: str or dtype object, dtype of features. Returns: `nibabel.spatialimages.SpatialImage` or arrays of predictions of mean, variance(optional), and entropy (optional). """ if not Path(filepath[0]).is_file(): raise FileNotFoundError("could not find file {}".format(filepath[0])) img = nib.load(filepath[0]) y = read_volume(filepath[1], dtype=np.int32) outputs = _predict( inputs=img, predictor=predictor, block_shape=block_shape, return_variance=return_variance, return_entropy=return_entropy, return_array_from_images=return_array_from_images, n_samples=n_samples, normalizer=normalizer, batch_size=batch_size, ) prediction_image = outputs[0].get_data() y = replace(y, read_mapping(mapping_y)) dice = get_dice_for_images(prediction_image, y, n_classes) return outputs, dice def get_dice_for_images(pred, gt, n_classes): """Computes dice for a prediction compared to a ground truth image. Args: pred: nibabel.spatialimages.SpatialImage, a predicted image. gt: nibabel.spatialimages.SpatialImage, a ground-truth image. Returns: `nibabel.spatialimages.SpatialImage`. 
""" dice = np.zeros(n_classes) for i in range(n_classes): u = np.equal(pred, i) v = np.equal(gt, i) dice[i] = dice_numpy(u, v) return dice def validate_from_filepaths( filepaths, predictor, block_shape, n_classes, mapping_y, output_path, return_variance=False, return_entropy=False, return_array_from_images=False, n_samples=1, normalizer=normalize_numpy, batch_size=4, dtype=DT_X, ): """Yield predictions from filepaths using a SavedModel. Args: test_csv: list, neuroimaging volume filepaths on which to predict. n_classes: int, number of classifications the model is trained to output. mapping_y: path-like, path to csv mapping file per command line argument. block_shape: tuple of len 3, shape of blocks on which to predict. predictor: TensorFlow Predictor object, predictor from previously trained model. block_shape: tuple of len 3, shape of blocks on which to predict. normalizer: callable, function that accepts an ndarray and returns an ndarray. Called before separating volume into blocks. batch_size: int, number of sub-volumes per batch for prediction. dtype: str or dtype object, dtype of features. 
Returns: None """ for filepath in filepaths: outputs, dice = validate_from_filepath( filepath=filepath, predictor=predictor, n_classes=n_classes, mapping_y=mapping_y, block_shape=block_shape, return_variance=return_variance, return_entropy=return_entropy, return_array_from_images=return_array_from_images, n_samples=n_samples, normalizer=normalizer, batch_size=batch_size, dtype=dtype, ) outpath = Path(filepath[0]) output_path = Path(output_path) suffixes = "".join(s for s in outpath.suffixes) mean_path = output_path / (outpath.stem + "_mean" + suffixes) variance_path = output_path / (outpath.stem + "_variance" + suffixes) entropy_path = output_path / (outpath.stem + "_entropy" + suffixes) dice_path = output_path / (outpath.stem + "_dice.npy") # if mean_path.is_file() or variance_path.is_file() or entropy_path.is_file(): # raise Exception(str(mean_path) + " or " + str(variance_path) + # " or " + str(entropy_path) + " already exists.") nib.save(outputs[0], mean_path.as_posix()) # fix if not return_array_from_images: include_variance = (n_samples > 1) and (return_variance) include_entropy = (n_samples > 1) and (return_entropy) if include_variance and return_entropy: nib.save(outputs[1], str(variance_path)) nib.save(outputs[2], str(entropy_path)) elif include_variance: nib.save(outputs[1], str(variance_path)) elif include_entropy: nib.save(outputs[1], str(entropy_path)) print(filepath[0]) print("Dice: " + str(np.mean(dice))) np.save(dice_path, dice) ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] name = "nobrainer" dynamic = ["version"] description = "A deep learning framework for 3D brain image processing." 
readme = "README.md" license = "Apache-2.0" requires-python = ">= 3.12" authors = [ { name = "Nobrainer Developers", email = "jakub.kaczmarzyk@gmail.com" }, ] maintainers = [ { name = "Satrajit Ghosh", email = "satrajit.ghosh@gmail.com" }, ] classifiers = [ "Development Status :: 4 - Beta", "Environment :: Console", "Intended Audience :: Developers", "Intended Audience :: Education", "Intended Audience :: Healthcare Industry", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development", "Topic :: Software Development :: Libraries :: Python Modules", ] dependencies = [ "click", "einops >= 0.7", "fsspec", "h5py >= 3.9", "joblib", "monai >= 1.3", "nibabel >= 5.0", "numpy", "psutil", "scikit-image", "torch >= 2.0", ] [project.urls] Homepage = "https://github.com/neuronets/nobrainer" Documentation = "https://neuronets.dev/nobrainer-book/" "Source Code" = "https://github.com/neuronets/nobrainer" "Bug Tracker" = "https://github.com/neuronets/helpdesk/issues" [project.scripts] nobrainer = "nobrainer.cli.main:cli" [project.optional-dependencies] bayesian = ["pyro-ppl >= 1.9"] generative = ["pytorch-lightning >= 2.0"] lightning = ["pytorch-lightning >= 2.0"] zarr = ["zarr >= 3.0", "nifti-zarr", "scipy >= 1.11"] croissant = ["mlcroissant"] versioning = ["datalad >= 0.19"] tfrecord = ["tfrecord >= 1.14"] dev = ["pre-commit", "pytest", "pytest-cov", "scipy"] all = [ "nobrainer[bayesian,generative,zarr,croissant,versioning,tfrecord,dev]", ] [tool.hatch.version] source = "vcs" fallback-version = "0.0.0.dev0" [tool.hatch.build.hooks.vcs] version-file = "nobrainer/_version.py" [tool.hatch.build.targets.wheel] 
packages = ["nobrainer"] [tool.black] exclude = '\.eggs|\.git|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist|_version\.py|versioneer\.py' [tool.isort] profile = "black" force_sort_within_sections = true reverse_relative = true sort_relative_in_force_sorted_sections = true known_first_party = ["nobrainer"] [tool.pytest.ini_options] testpaths = ["nobrainer/tests"] markers = [ "gpu: marks tests that require a CUDA-capable GPU (deselect with '-m not gpu')", ] [tool.coverage.run] branch = true omit = ["nobrainer/_version.py", "*/tests*"] [tool.coverage.report] exclude_lines = [ "pragma: no cover", "raise NotImplementedError", "if __name__ == .__main__.", ] ignore_errors = true [tool.codespell] skip = "nobrainer/_version.py,versioneer.py" ignore-words-list = "nd" ================================================ FILE: scripts/kwyk_reproduction/01_assemble_dataset.py ================================================ #!/usr/bin/env python """Assemble training dataset from OpenNeuro fmriprep derivatives. Uses :mod:`nobrainer.datasets.openneuro` to install datasets via DataLad and discover paired (T1w, aparc+aseg) files per subject. 
Usage: python 01_assemble_dataset.py --datasets ds000114 --output-csv manifest.csv python 01_assemble_dataset.py --datasets ds000114 ds000228 ds002609 \ --output-csv manifest.csv --label-mapping binary --split 80 10 10 """ from __future__ import annotations import argparse import logging from pathlib import Path logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s", ) log = logging.getLogger(__name__) def main(): parser = argparse.ArgumentParser(description="Assemble dataset from OpenNeuro") parser.add_argument( "--datasets", nargs="+", default=["ds000114"], help="OpenNeuro dataset IDs", ) parser.add_argument("--output-dir", default="data", help="Output directory") parser.add_argument( "--output-csv", default="manifest.csv", help="Output manifest CSV" ) parser.add_argument( "--label-mapping", default="binary", help="Label mapping: binary, 6-class, 50-class, 115-class", ) parser.add_argument( "--split", nargs=3, type=int, default=[80, 10, 10], help="Train/val/test split percentages", ) parser.add_argument("--conform", action="store_true", help="Resample to 256³ @ 1mm") args = parser.parse_args() from nobrainer.datasets.openneuro import ( find_subject_pairs, install_derivatives, write_manifest, ) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) all_pairs = [] for ds_id in args.datasets: ds_dir = install_derivatives(ds_id, output_dir) pairs = find_subject_pairs(ds_dir) for p in pairs: p["dataset_id"] = ds_id all_pairs.extend(pairs) if not all_pairs: log.error("No subject pairs found. 
Check dataset IDs and network access.") raise SystemExit(1) # Optionally conform volumes if args.conform: import nibabel as nib from nibabel.processing import conform for row in all_pairs: img = nib.load(row["t1w_path"]) if img.shape[:3] != (256, 256, 256): log.info("Conforming %s", Path(row["t1w_path"]).name) conformed = conform( img, out_shape=(256, 256, 256), voxel_size=(1.0, 1.0, 1.0) ) nib.save(conformed, row["t1w_path"]) write_manifest(all_pairs, args.output_csv, tuple(args.split)) if __name__ == "__main__": main() ================================================ FILE: scripts/kwyk_reproduction/02_train_meshnet.py ================================================ #!/usr/bin/env python """Train a deterministic MeshNet for brain extraction / parcellation. Usage: python 02_train_meshnet.py --manifest manifest.csv --config config.yaml python 02_train_meshnet.py --manifest manifest.csv --config config.yaml \ --output-dir checkpoints/meshnet --epochs 100 """ from __future__ import annotations import argparse import csv from pathlib import Path import time import matplotlib.pyplot as plt import numpy as np import torch from utils import load_config, save_figure, setup_logging log = setup_logging(__name__) def parse_args() -> argparse.Namespace: """Parse command-line arguments.""" parser = argparse.ArgumentParser( description="Train deterministic MeshNet for brain segmentation", ) parser.add_argument( "--manifest", type=str, required=True, help="Path to the dataset manifest CSV (output of 01_assemble_dataset.py)", ) parser.add_argument( "--config", type=str, default="config.yaml", help="Path to YAML configuration file (default: config.yaml)", ) parser.add_argument( "--output-dir", type=str, default="checkpoints/meshnet", help="Directory for saving model checkpoints and figures", ) parser.add_argument( "--epochs", type=int, default=None, help="Override number of training epochs from config", ) parser.add_argument( "--resume", type=str, default=None, help="Path to 
checkpoint .pth file to resume from", ) return parser.parse_args() def load_manifest(manifest_path: str, split: str) -> list[tuple[str, str]]: """Load manifest CSV and return (image, label) pairs for the given split.""" pairs = [] with open(manifest_path) as f: reader = csv.DictReader(f) for row in reader: if row["split"] == split: pairs.append((row["t1w_path"], row["label_path"])) return pairs def evaluate_val_dice( seg, val_pairs: list[tuple[str, str]], block_shape: tuple[int, int, int], label_mapping: str | None, n_classes: int = 2, ) -> list[float]: """Compute per-volume mean class Dice on validation set. Returns a list of mean Dice scores (averaged across classes), one per volume. """ import nibabel as nib from nobrainer.prediction import predict from nobrainer.training import get_device # Load remap function for multi-class label mappings remap_fn = None if label_mapping and label_mapping != "binary": from nobrainer.processing.dataset import _load_label_mapping remap_fn = _load_label_mapping(label_mapping) dice_scores = [] device = get_device() model = seg.model_.to(device) model.eval() for img_path, lbl_path in val_pairs: pred_img = predict( inputs=img_path, model=model, block_shape=block_shape, batch_size=128, return_labels=True, ) pred_arr = np.asarray(pred_img.dataobj, dtype=np.int32) gt_arr = np.asarray(nib.load(lbl_path).dataobj, dtype=np.int32) if remap_fn is not None: gt_arr = remap_fn(torch.from_numpy(gt_arr)).numpy() elif label_mapping is None or label_mapping == "binary": gt_arr = (gt_arr > 0).astype(np.int32) pred_arr = (pred_arr > 0).astype(np.int32) # Per-class Dice (skip background class 0) class_dices = [] for c in range(1, n_classes): pred_c = pred_arr == c gt_c = gt_arr == c intersection = (pred_c & gt_c).sum() total = pred_c.sum() + gt_c.sum() class_dices.append(2.0 * intersection / total if total > 0 else 1.0) mean_dice = float(np.mean(class_dices)) dice_scores.append(mean_dice) log.info( " Val volume %s: Dice=%.4f", Path(img_path).name, 
mean_dice, ) return dice_scores def plot_learning_curve( train_losses: list[float], val_dice_scores: list[float], output_path: Path, ) -> None: """Generate dual y-axis learning curve (loss left, Dice right).""" epochs = list(range(1, len(train_losses) + 1)) fig, ax_loss = plt.subplots(figsize=(10, 6)) ax_dice = ax_loss.twinx() ax_loss.plot(epochs, train_losses, "b-", label="Train Loss") ax_loss.set_xlabel("Epoch") ax_loss.set_ylabel("Loss", color="b") ax_loss.tick_params(axis="y", labelcolor="b") if val_dice_scores: ax_dice.plot(epochs, val_dice_scores, "r-o", label="Val Dice (mean)") ax_dice.set_ylabel("Dice Score", color="r") ax_dice.tick_params(axis="y", labelcolor="r") ax_dice.set_ylim(0.0, 1.0) fig.suptitle("MeshNet Training — Loss & Validation Dice") fig.tight_layout() lines_loss, labels_loss = ax_loss.get_legend_handles_labels() lines_dice, labels_dice = ax_dice.get_legend_handles_labels() ax_loss.legend( lines_loss + lines_dice, labels_loss + labels_dice, loc="center right" ) save_figure(fig, output_path) plt.close(fig) log.info("Learning curve saved to %s", output_path) def main() -> None: """Train deterministic MeshNet and evaluate on validation set.""" args = parse_args() t_start = time.time() # ---- Load config -------------------------------------------------------- config = load_config(args.config) epochs = ( args.epochs if args.epochs is not None else config.get("pretrain_epochs", 50) ) n_classes = config["n_classes"] block_shape = tuple(config["block_shape"]) batch_size = config["batch_size"] lr = config.get("lr", 1e-4) label_mapping = config.get("label_mapping", "binary") output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) log.info("Config loaded from %s", args.config) log.info( "Training MeshNet: epochs=%d, n_classes=%d, block_shape=%s, batch_size=%d", epochs, n_classes, block_shape, batch_size, ) # ---- Load manifest and build datasets ----------------------------------- train_pairs = load_manifest(args.manifest, 
split="train") val_pairs = load_manifest(args.manifest, split="val") log.info("Manifest: %d train, %d val volumes", len(train_pairs), len(val_pairs)) if not train_pairs: log.error("No training volumes found in manifest. Exiting.") return # Auto-scale batch size to GPU memory from nobrainer.gpu import auto_batch_size as _auto_bs from nobrainer.gpu import gpu_count from nobrainer.processing.dataset import Dataset if gpu_count() > 0: from nobrainer.models import get as get_model _tmp_model = get_model("meshnet")( n_classes=n_classes, filters=config.get("filters", 96), receptive_field=config.get("receptive_field", 37), dropout_rate=config.get("dropout_rate", 0.25), ) batch_size = _auto_bs( _tmp_model, block_shape, n_classes=n_classes, target_memory_fraction=0.90, ) del _tmp_model log.info("Auto batch size: %d (target 90%% GPU memory)", batch_size) patches_per_volume = config.get("patches_per_volume", 50) zarr_store = config.get("zarr_store") if zarr_store and Path(zarr_store).exists(): log.info("Using Zarr store: %s", zarr_store) ds_train = ( Dataset.from_zarr( zarr_store, block_shape=block_shape, n_classes=n_classes, partition="train", ) .batch(batch_size) .binarize(label_mapping) .streaming(patches_per_volume=patches_per_volume) ) else: ds_train = ( Dataset.from_files( train_pairs, block_shape=block_shape, n_classes=n_classes ) .batch(batch_size) .binarize(label_mapping) .streaming(patches_per_volume=patches_per_volume) ) n_train = len(ds_train.data) if hasattr(ds_train, "data") else len(train_pairs) log.info( "Training data: %d volumes × %d patches = %d blocks/epoch, batch_size=%d", n_train, patches_per_volume, n_train * patches_per_volume, batch_size, ) # ---- Build validation dataset for per-epoch block-level metrics ---------- ds_val = None if val_pairs: if zarr_store and Path(zarr_store).exists(): ds_val = ( Dataset.from_zarr( zarr_store, block_shape=block_shape, n_classes=n_classes, partition="val", ) .batch(batch_size) .binarize(label_mapping) 
.streaming(patches_per_volume=patches_per_volume) ) else: ds_val = ( Dataset.from_files( val_pairs, block_shape=block_shape, n_classes=n_classes ) .batch(batch_size) .binarize(label_mapping) .streaming(patches_per_volume=patches_per_volume) ) # ---- Train with Segmentation estimator ---------------------------------- from nobrainer.processing.segmentation import Segmentation model_args = { "n_classes": n_classes, "filters": config.get("filters", 96), "receptive_field": config.get("receptive_field", 37), "dropout_rate": config.get("dropout_rate", 0.25), } log.info("Model args: %s", model_args) seg = Segmentation( base_model="meshnet", model_args=model_args, checkpoint_filepath=str(output_dir), ) val_dice_per_epoch: list[float] = [] val_dice_freq = config.get("val_dice_freq", 5) # Simple logging callback (picklable — no closures) def _log_cb(epoch, logs, model): msg = f"Epoch {epoch + 1}/{epochs}: train_loss={logs['loss']:.6f}" if "val_loss" in logs: msg += f" val_loss={logs['val_loss']:.6f}" if "val_acc" in logs: msg += f" val_acc={logs['val_acc']:.4f}" if "val_bal_acc" in logs: msg += f" bal_acc={logs['val_bal_acc']:.4f}" log.info(msg) seg.fit( dataset_train=ds_train, dataset_validate=ds_val, epochs=epochs, optimizer=torch.optim.Adam, opt_args={"lr": lr}, callbacks=[_log_cb], checkpoint_freq=val_dice_freq, gradient_checkpointing=config.get("gradient_checkpointing", False), model_parallel=config.get("model_parallel", False), resume_from=args.resume, ) history = seg._training_result.get("history", []) if history: last = history[-1] log.info( "Training complete. 
%s", " ".join(f"{k}={v:.4f}" for k, v in last.items() if isinstance(v, float)), ) else: log.info("Training complete (no history).") # Ensure model is on the right device after DDP from nobrainer.training import get_device seg.model_.to(get_device()) # Evaluate full-volume Dice on each checkpointed epoch if val_pairs: for epoch_idx in range(len(history)): epoch_num = history[epoch_idx].get("epoch", epoch_idx + 1) ckpt_file = output_dir / f"epoch_{epoch_num:03d}.pth" if ckpt_file.exists(): log.info("Evaluating Dice at epoch %d...", epoch_num) seg.model_.load_state_dict( torch.load(ckpt_file, map_location=get_device(), weights_only=True) ) dice_scores = evaluate_val_dice( seg, val_pairs, block_shape, label_mapping, n_classes ) mean_dice = float(np.mean(dice_scores)) if dice_scores else 0.0 history[epoch_idx]["val_dice"] = mean_dice log.info(" Epoch %d Dice: %.4f", epoch_num, mean_dice) train_losses = [h["loss"] for h in history] val_dice_per_epoch = [h.get("val_dice", float("nan")) for h in history] fig_path = output_dir / "learning_curve.png" plot_learning_curve(train_losses, val_dice_per_epoch, fig_path) # ---- Save model with Croissant-ML metadata ------------------------------ seg.save(output_dir) log.info("Model and Croissant-ML metadata saved to %s", output_dir) # ---- Summary ------------------------------------------------------------ elapsed = time.time() - t_start log.info("=" * 60) log.info("MeshNet training complete") log.info(" Output directory : %s", output_dir) log.info(" Epochs : %d", epochs) log.info(" Final train loss : %.6f", train_losses[-1] if train_losses else 0.0) final_dice = [d for d in val_dice_per_epoch if not np.isnan(d)] if final_dice: log.info(" Val Dice (mean) : %.4f", final_dice[-1]) log.info(" Elapsed time : %.1f s", elapsed) log.info("=" * 60) if __name__ == "__main__": main() ================================================ FILE: scripts/kwyk_reproduction/03_train_bayesian.py ================================================ 
#!/usr/bin/env python """Train a Bayesian MeshNet with optional warm-start from deterministic weights. Supports the three kwyk model variants: - bvwn_multi_prior: Spike-and-slab dropout (default) - bayesian_gaussian: Standard Gaussian prior - bwn_multi: MC Bernoulli dropout (deterministic model, dropout at inference) Usage: # Spike-and-slab (original kwyk variant) python 03_train_bayesian.py --manifest manifest.csv --config config.yaml \ --variant bvwn_multi_prior --warmstart checkpoints/meshnet # Standard Gaussian prior python 03_train_bayesian.py --manifest manifest.csv --config config.yaml \ --variant bayesian_gaussian --warmstart checkpoints/meshnet # MC Bernoulli dropout (copies deterministic weights, uses dropout at inference) python 03_train_bayesian.py --manifest manifest.csv --config config.yaml \ --variant bwn_multi --warmstart checkpoints/meshnet # Override epochs python 03_train_bayesian.py --manifest manifest.csv --config config.yaml \ --warmstart checkpoints/meshnet --epochs 100 """ from __future__ import annotations import argparse import csv import os from pathlib import Path import time import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn from utils import load_config, save_figure, setup_logging from nobrainer.gpu import auto_batch_size, gpu_count from nobrainer.slurm import SlurmPreemptionHandler, load_checkpoint, save_checkpoint log = setup_logging(__name__) # --------------------------------------------------------------------------- # ELBO loss: CrossEntropy + KL divergence from Bayesian layers # --------------------------------------------------------------------------- class ELBOLoss(nn.Module): """Evidence Lower Bound loss combining CE and KL divergence. Parameters ---------- model : nn.Module Bayesian model whose layers carry ``.kl`` attributes after each forward pass. kl_weight : float Scaling factor for the KL term. 
``1.0`` corresponds to the standard variational free-energy; smaller values down-weight the regularisation (cold posterior). """ def __init__( self, model: nn.Module, kl_weight: float = 1.0, class_weights: torch.Tensor | None = None, ) -> None: super().__init__() self.ce = nn.CrossEntropyLoss(weight=class_weights) self.model = model self.kl_weight = kl_weight self._last_kl: float = 0.0 def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: from nobrainer.models.bayesian.utils import accumulate_kl ce_loss = self.ce(pred, target) kl_loss = accumulate_kl(self.model) self._last_kl = kl_loss.item() return ce_loss + self.kl_weight * kl_loss # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def parse_args() -> argparse.Namespace: """Parse command-line arguments.""" parser = argparse.ArgumentParser( description="Train Bayesian MeshNet with optional warm-start", ) parser.add_argument( "--manifest", type=str, required=True, help="Path to the dataset manifest CSV", ) parser.add_argument( "--config", type=str, default="config.yaml", help="Path to YAML configuration file", ) parser.add_argument( "--output-dir", type=str, default="checkpoints/bayesian", help="Directory for saving model checkpoints and figures", ) parser.add_argument( "--variant", type=str, default="bvwn_multi_prior", choices=["bvwn_multi_prior", "bayesian_gaussian", "bwn_multi"], help=( "Model variant: bvwn_multi_prior (spike-and-slab, default), " "bayesian_gaussian (Gaussian prior), bwn_multi (MC Bernoulli dropout)" ), ) parser.add_argument( "--warmstart", type=str, default=None, help="Path to a trained deterministic MeshNet directory (containing model.pth)", ) parser.add_argument( "--no-warmstart", action="store_true", help="Explicitly disable warm-start (train from scratch)", ) parser.add_argument( "--epochs", type=int, default=None, help="Override number of training epochs 
from config", ) return parser.parse_args() def load_manifest(manifest_path: str, split: str) -> list[tuple[str, str]]: """Load manifest CSV and return (image, label) pairs for the given split.""" pairs = [] with open(manifest_path) as f: reader = csv.DictReader(f) for row in reader: if row["split"] == split: pairs.append((row["t1w_path"], row["label_path"])) return pairs # --------------------------------------------------------------------------- # Validation with MC inference # --------------------------------------------------------------------------- def evaluate_mc_dice( model: nn.Module, val_pairs: list[tuple[str, str]], block_shape: tuple[int, int, int], n_samples: int, label_mapping: str | None, n_classes: int = 2, ) -> tuple[list[float], list[float]]: """Run MC inference on each validation volume. Returns ------- mean_dices : list[float] Mean class Dice across MC samples for each volume. std_dices : list[float] Std of Dice across MC samples for each volume. """ import nibabel as nib from nobrainer.prediction import predict from nobrainer.training import get_device # Load remap function for multi-class label mappings remap_fn = None if label_mapping and label_mapping != "binary": from nobrainer.processing.dataset import _load_label_mapping remap_fn = _load_label_mapping(label_mapping) mean_dices: list[float] = [] std_dices: list[float] = [] device = get_device() model = model.to(device) for img_path, lbl_path in val_pairs: gt_arr = np.asarray(nib.load(lbl_path).dataobj, dtype=np.int32) if remap_fn is not None: gt_arr = remap_fn(torch.from_numpy(gt_arr)).numpy() elif label_mapping is None or label_mapping == "binary": gt_arr = (gt_arr > 0).astype(np.int32) # Multiple stochastic forward passes sample_dices: list[float] = [] for s in range(n_samples): model.train() pred_img = predict( inputs=img_path, model=model, block_shape=block_shape, batch_size=128, return_labels=True, ) pred_arr = np.asarray(pred_img.dataobj, dtype=np.int32) if label_mapping is None or 
label_mapping == "binary": pred_arr = (pred_arr > 0).astype(np.int32) # Per-class Dice (skip background) class_dices = [] for c in range(1, n_classes): pred_c = pred_arr == c gt_c = gt_arr == c intersection = (pred_c & gt_c).sum() total = pred_c.sum() + gt_c.sum() class_dices.append(2.0 * intersection / total if total > 0 else 1.0) sample_dices.append(float(np.mean(class_dices))) vol_mean = float(np.mean(sample_dices)) vol_std = float(np.std(sample_dices)) mean_dices.append(vol_mean) std_dices.append(vol_std) log.info( " Val volume %s: MC Dice=%.4f +/- %.4f (%d samples)", Path(img_path).name, vol_mean, vol_std, n_samples, ) return mean_dices, std_dices # --------------------------------------------------------------------------- # Learning curve with uncertainty bands # --------------------------------------------------------------------------- def plot_learning_curve( train_losses: list[float], val_losses: list[float], val_dice_means: list[float], val_dice_stds: list[float], kl_terms: list[float], output_path: Path, ) -> None: """Generate learning curve with uncertainty bands. Left y-axis: train loss, val loss, KL term. Right y-axis: mean MC Dice with +/- std shading. 
""" epochs = list(range(1, len(train_losses) + 1)) fig, ax_loss = plt.subplots(figsize=(12, 7)) ax_dice = ax_loss.twinx() # Loss curves ax_loss.plot(epochs, train_losses, "b-", label="Train Loss (ELBO)") if val_losses: ax_loss.plot(epochs, val_losses, "b--", alpha=0.7, label="Val Loss") if kl_terms: ax_loss.plot(epochs, kl_terms, "g-.", alpha=0.6, label="KL Term") ax_loss.set_xlabel("Epoch") ax_loss.set_ylabel("Loss / KL", color="b") ax_loss.tick_params(axis="y", labelcolor="b") # Dice with uncertainty bands if val_dice_means: means = np.array(val_dice_means) stds = np.array(val_dice_stds) dice_epochs = list(range(1, len(means) + 1)) ax_dice.plot( dice_epochs, means, "r-o", markersize=3, label="Val MC Dice (mean)" ) ax_dice.fill_between( dice_epochs, np.clip(means - stds, 0, 1), np.clip(means + stds, 0, 1), color="r", alpha=0.15, label="Val MC Dice (+/- std)", ) ax_dice.set_ylabel("Dice Score", color="r") ax_dice.tick_params(axis="y", labelcolor="r") ax_dice.set_ylim(0.0, 1.0) fig.suptitle("Bayesian MeshNet Training — ELBO Loss & MC Dice") fig.tight_layout() lines_loss, labels_loss = ax_loss.get_legend_handles_labels() lines_dice, labels_dice = ax_dice.get_legend_handles_labels() ax_loss.legend( lines_loss + lines_dice, labels_loss + labels_dice, loc="center right", ) save_figure(fig, output_path) plt.close(fig) log.info("Learning curve saved to %s", output_path) # --------------------------------------------------------------------------- # Training loop (lower-level, using nobrainer.training.fit) # --------------------------------------------------------------------------- def train_bayesian( model: nn.Module, train_loader, val_loader, elbo_loss: ELBOLoss, optimizer: torch.optim.Optimizer, epochs: int, val_pairs: list[tuple[str, str]], block_shape: tuple[int, int, int], n_samples: int, label_mapping: str | None, n_classes: int, checkpoint_dir: Path, preemption_handler: SlurmPreemptionHandler | None = None, callbacks: list | None = None, ) -> dict: """Custom 
training loop for Bayesian MeshNet with ELBO loss. Supports checkpoint/resume for SLURM preemptible jobs. When a preemption signal is received, the loop checkpoints and exits so the job can be requeued. """ from nobrainer.training import get_device device = get_device() model = model.to(device) # -- Resume from checkpoint if available -------------------------------- start_epoch, prev_metrics = load_checkpoint(checkpoint_dir, model, optimizer) # Restore accumulated metrics from prior runs train_losses: list[float] = prev_metrics.get("train_losses", []) val_losses_list: list[float] = prev_metrics.get("val_losses", []) val_dice_means: list[float] = prev_metrics.get("val_dice_means", []) val_dice_stds: list[float] = prev_metrics.get("val_dice_stds", []) kl_terms: list[float] = prev_metrics.get("kl_terms", []) best_loss: float = prev_metrics.get("best_loss", float("inf")) if start_epoch >= epochs: log.info("Already completed %d/%d epochs — nothing to do", start_epoch, epochs) return { "train_losses": train_losses, "val_losses": val_losses_list, "val_dice_means": val_dice_means, "val_dice_stds": val_dice_stds, "kl_terms": kl_terms, "best_loss": best_loss, "epochs_completed": start_epoch, } for epoch in range(start_epoch, epochs): t_epoch = time.time() # -- Train one epoch -------------------------------------------------- model.train() epoch_loss = 0.0 epoch_kl = 0.0 n_batches = 0 for batch in train_loader: if isinstance(batch, dict): images = batch["image"].to(device) labels = batch["label"].to(device) elif isinstance(batch, (list, tuple)): images = batch[0].to(device) labels = batch[1].to(device) else: raise TypeError(f"Unsupported batch type: {type(batch)}") # Squeeze channel dim from labels if present if labels.ndim == images.ndim and labels.shape[1] == 1: labels = labels.squeeze(1) if labels.dtype in (torch.float32, torch.float64): labels = labels.long() optimizer.zero_grad() # Match original TF: deterministic VWN weights + stochastic dropout # (is_mc_v=False, 
is_mc_b=True in meshnetbwn.py) try: pred = model(images, mc_vwn=False, mc_dropout=True) except TypeError: pred = model(images) loss = elbo_loss(pred, labels) loss.backward() optimizer.step() epoch_loss += loss.item() epoch_kl += elbo_loss._last_kl n_batches += 1 avg_loss = epoch_loss / max(n_batches, 1) avg_kl = epoch_kl / max(n_batches, 1) train_losses.append(avg_loss) kl_terms.append(avg_kl) # -- Validate --------------------------------------------------------- val_loss = 0.0 if val_loader is not None: model.eval() n_val = 0 with torch.no_grad(): for batch in val_loader: if isinstance(batch, dict): images = batch["image"].to(device) labels = batch["label"].to(device) elif isinstance(batch, (list, tuple)): images = batch[0].to(device) labels = batch[1].to(device) else: raise TypeError(f"Unsupported batch type: {type(batch)}") if labels.ndim == images.ndim and labels.shape[1] == 1: labels = labels.squeeze(1) if labels.dtype in (torch.float32, torch.float64): labels = labels.long() try: pred = model(images, mc_vwn=False, mc_dropout=False) except TypeError: pred = model(images) loss = elbo_loss(pred, labels) val_loss += loss.item() n_val += 1 val_loss = val_loss / max(n_val, 1) val_losses_list.append(val_loss) # -- MC Dice evaluation (every 10 epochs or last epoch) --------------- if val_pairs and (epoch == epochs - 1 or (epoch + 1) % 10 == 0): mean_dices, std_dices = evaluate_mc_dice( model, val_pairs, block_shape, n_samples, label_mapping, n_classes ) overall_mean = float(np.mean(mean_dices)) if mean_dices else 0.0 overall_std = float(np.mean(std_dices)) if std_dices else 0.0 val_dice_means.append(overall_mean) val_dice_stds.append(overall_std) else: if val_dice_means: val_dice_means.append(val_dice_means[-1]) val_dice_stds.append(val_dice_stds[-1]) else: val_dice_means.append(float("nan")) val_dice_stds.append(float("nan")) # -- Checkpoint best -------------------------------------------------- if avg_loss < best_loss: best_loss = avg_loss 
torch.save(model.state_dict(), checkpoint_dir / "best_model.pth") # -- Always save resumable checkpoint --------------------------------- metrics = { "train_losses": train_losses, "val_losses": val_losses_list, "val_dice_means": val_dice_means, "val_dice_stds": val_dice_stds, "kl_terms": kl_terms, "best_loss": best_loss, } save_checkpoint(checkpoint_dir, model, optimizer, epoch, metrics) elapsed = time.time() - t_epoch log.info( "Epoch %d/%d: train_loss=%.6f val_loss=%.6f kl=%.6f " "dice=%.4f (%.1fs)", epoch + 1, epochs, avg_loss, val_loss, avg_kl, val_dice_means[-1] if val_dice_means else 0.0, elapsed, ) # -- Callbacks ----------------------------------------------------------- for cb in callbacks or []: cb(epoch, avg_loss, model) # -- Check for SLURM preemption signal -------------------------------- if preemption_handler and preemption_handler.preempted: log.warning( "Preemption detected after epoch %d — exiting for requeue", epoch + 1, ) break return { "train_losses": train_losses, "val_losses": val_losses_list, "val_dice_means": val_dice_means, "val_dice_stds": val_dice_stds, "kl_terms": kl_terms, "best_loss": best_loss, "epochs_completed": epoch + 1, } # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: """Train Bayesian MeshNet with optional warm-start.""" args = parse_args() t_start = time.time() # ---- Load config -------------------------------------------------------- config = load_config(args.config) epochs = ( args.epochs if args.epochs is not None else config.get("bayesian_epochs", 50) ) n_classes = config["n_classes"] block_shape = tuple(config["block_shape"]) batch_size = config["batch_size"] lr = config.get("lr", 1e-4) n_samples = config.get("n_samples", 10) label_mapping = config.get("label_mapping", "binary") output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) use_warmstart = 
args.warmstart is not None and not args.no_warmstart # ---- Load variant config from config.yaml variants section -------------- variant = args.variant variants = config.get("variants", {}) variant_config = variants.get(variant, {}) log.info("Model variant: %s — %s", variant, variant_config.get("description", "")) # ---- Determine KL weight from config ------------------------------------ kl_weight = variant_config.get("kl_weight", config.get("kl_weight", 1.0)) log.info("Config loaded from %s", args.config) log.info( "Training KWYK MeshNet (%s): epochs=%d, n_classes=%d, " "block_shape=%s, kl_weight=%.4f, warmstart=%s", variant, epochs, n_classes, block_shape, kl_weight, use_warmstart, ) # ---- Load manifest and build datasets ----------------------------------- train_pairs = load_manifest(args.manifest, split="train") val_pairs = load_manifest(args.manifest, split="val") log.info("Manifest: %d train, %d val volumes", len(train_pairs), len(val_pairs)) if not train_pairs: log.error("No training volumes found in manifest. 
Exiting.") return # ---- Build KWYK MeshNet first (needed for auto batch size) ---------------- from nobrainer.models import get as get_model from nobrainer.processing.dataset import Dataset dropout_type = variant_config.get("dropout_type", "bernoulli") model_args = { "n_classes": n_classes, "filters": config.get("filters", 96), "receptive_field": config.get("receptive_field", 37), "dropout_type": dropout_type, "dropout_rate": config.get("dropout_rate", 0.25), "sigma_init": config.get("sigma_init", 1e-4), } # Concrete dropout specific params if dropout_type == "concrete": model_args["concrete_temperature"] = variant_config.get( "concrete_temperature", 0.02 ) model_args["concrete_init_p"] = variant_config.get("concrete_init_p", 0.9) log.info("Model args: %s", model_args) bayesian_model = get_model("kwyk_meshnet")(**model_args) log.info( "KWYK MeshNet (%s, %s) created: %d parameters", variant, dropout_type, sum(p.numel() for p in bayesian_model.parameters()), ) # ---- Auto batch size with training-mode profiling ------------------------- n_gpus = gpu_count() if n_gpus > 0: optimal_per_gpu = auto_batch_size( bayesian_model, block_shape, n_classes=n_classes, target_memory_fraction=0.90, forward_kwargs={"mc_vwn": False, "mc_dropout": True}, ) log.info( "Auto batch size: %d (profiled with mc_vwn=False, mc_dropout=True, " "config batch_size=%d)", optimal_per_gpu, batch_size, ) batch_size = optimal_per_gpu # ---- Build datasets with optimized batch size ---------------------------- # Use streaming mode: extract multiple patches per volume to fill GPU. 
# Use Zarr store if available, else fall back to NIfTI with streaming patches_per_volume = config.get("patches_per_volume", 50) zarr_store = config.get("zarr_store") if zarr_store and Path(zarr_store).exists(): log.info("Using Zarr store: %s", zarr_store) ds_train = ( Dataset.from_zarr( zarr_store, block_shape=block_shape, n_classes=n_classes, partition="train", ) .batch(batch_size) .binarize(label_mapping) .streaming(patches_per_volume=patches_per_volume) ) else: ds_train = ( Dataset.from_files( train_pairs, block_shape=block_shape, n_classes=n_classes ) .batch(batch_size) .binarize(label_mapping) .streaming(patches_per_volume=patches_per_volume) ) train_loader = ds_train.dataloader n_train = len(ds_train.data) if hasattr(ds_train, "data") else len(train_pairs) log.info( "Training data: %d volumes × %d patches = %d blocks/epoch, batch_size=%d", n_train, patches_per_volume, n_train * patches_per_volume, batch_size, ) ds_val = None val_loader = None if val_pairs: if zarr_store and Path(zarr_store).exists(): ds_val = ( Dataset.from_zarr( zarr_store, block_shape=block_shape, n_classes=n_classes, partition="val", ) .batch(batch_size) .binarize(label_mapping) .streaming(patches_per_volume=patches_per_volume) ) else: ds_val = ( Dataset.from_files( val_pairs, block_shape=block_shape, n_classes=n_classes ) .batch(batch_size) .binarize(label_mapping) .streaming(patches_per_volume=patches_per_volume) ) val_loader = ds_val.dataloader # ---- Optional warm-start ------------------------------------------------ if use_warmstart: warmstart_dir = Path(args.warmstart) det_weights_path = warmstart_dir / "model.pth" if not det_weights_path.exists(): log.error( "Warm-start weights not found at %s. 
" "Train a deterministic MeshNet first with 02_train_meshnet.py.", det_weights_path, ) return log.info("Loading deterministic weights from %s", det_weights_path) from nobrainer.models.bayesian.warmstart import ( warmstart_kwyk_from_deterministic, ) n_transferred = warmstart_kwyk_from_deterministic( bayesian_model, det_weights_path, get_model, ) log.info("Warm-started %d layers from deterministic model", n_transferred) else: log.info("Training KWYK MeshNet from scratch (no warm-start)") # ---- Class weights (important for 50-class parcellation) ----------------- class_weights = None weight_method = config.get("class_weight_method") if weight_method and weight_method != "null": from nobrainer.losses import compute_class_weights label_paths = [p[1] for p in train_pairs] class_weights = compute_class_weights( label_paths, n_classes, label_mapping=label_mapping, method=weight_method, max_samples=50, ) log.info( "Class weights computed (%s): min=%.3f, max=%.3f, mean=%.3f", weight_method, class_weights.min(), class_weights.max(), class_weights.mean(), ) # Move weights to device from nobrainer.training import get_device class_weights = class_weights.to(get_device()) # ---- ELBO loss and optimiser -------------------------------------------- elbo_loss = ELBOLoss( bayesian_model, kl_weight=kl_weight, class_weights=class_weights ) optimizer = torch.optim.Adam(bayesian_model.parameters(), lr=lr) # ---- SLURM preemption handler (no-op if not on SLURM) ----------------- preemption = None if os.environ.get("SLURM_JOB_ID"): preemption = SlurmPreemptionHandler() # ---- Experiment tracker (local + optional W&B) ------------------------- from nobrainer.experiment import ExperimentTracker tracker = ExperimentTracker( output_dir=output_dir, config={ "variant": variant, "dropout_type": dropout_type, "n_classes": n_classes, "filters": config.get("filters", 96), "block_shape": list(block_shape), "batch_size": batch_size, "lr": lr, "kl_weight": kl_weight, "epochs": epochs, "warmstart": 
use_warmstart, }, project="kwyk-reproduction", name=variant, tags=[variant, f"{n_classes}-class"], ) # ---- Train -------------------------------------------------------------- result = train_bayesian( model=bayesian_model, train_loader=train_loader, val_loader=val_loader, elbo_loss=elbo_loss, optimizer=optimizer, epochs=epochs, val_pairs=val_pairs, block_shape=block_shape, n_samples=n_samples, label_mapping=label_mapping, n_classes=n_classes, checkpoint_dir=output_dir, preemption_handler=preemption, callbacks=[tracker.callback(variant=variant)], ) # ---- Learning curve with uncertainty bands ------------------------------ fig_path = output_dir / "learning_curve.png" plot_learning_curve( train_losses=result["train_losses"], val_losses=result["val_losses"], val_dice_means=result["val_dice_means"], val_dice_stds=result["val_dice_stds"], kl_terms=result["kl_terms"], output_path=fig_path, ) # ---- Save with Croissant-ML metadata ------------------------------------ # Save final weights torch.save(bayesian_model.state_dict(), output_dir / "model.pth") # Use Segmentation estimator's save for Croissant metadata from nobrainer.processing.segmentation import Segmentation seg = Segmentation( base_model="kwyk_meshnet", model_args=model_args, ) seg.model_ = bayesian_model seg.block_shape_ = block_shape seg.n_classes_ = n_classes seg._optimizer_class = "Adam" seg._optimizer_args = {"lr": lr} seg._loss_name = "ELBOLoss" seg._training_result = { "variant": variant, "dropout_type": dropout_type, "final_loss": result["train_losses"][-1] if result["train_losses"] else 0.0, "best_loss": result["best_loss"], "epochs_completed": result["epochs_completed"], "checkpoint_path": str(output_dir / "best_model.pth"), } seg._dataset = ds_train seg.save(output_dir) log.info("Model and Croissant-ML metadata saved to %s", output_dir) # ---- Summary ------------------------------------------------------------ elapsed = time.time() - t_start final_dice = ( result["val_dice_means"][-1] if 
result["val_dice_means"] else float("nan") ) log.info("=" * 60) log.info("Bayesian MeshNet training complete (%s)", variant) log.info(" Output directory : %s", output_dir) log.info(" Variant : %s", variant) log.info(" Dropout type : %s", dropout_type) log.info(" Epochs : %d", epochs) log.info(" Warm-start : %s", "yes" if use_warmstart else "no") log.info(" KL weight : %.4f", kl_weight) log.info( " Final train loss : %.6f", result["train_losses"][-1] if result["train_losses"] else 0.0, ) log.info(" Best train loss : %.6f", result["best_loss"]) log.info(" Val MC Dice : %.4f", final_dice) log.info(" MC samples : %d", n_samples) log.info(" Elapsed time : %.1f s", elapsed) log.info("=" * 60) tracker.finish() if __name__ == "__main__": main() ================================================ FILE: scripts/kwyk_reproduction/04_evaluate.py ================================================ #!/usr/bin/env python """Evaluate a trained segmentation model on test volumes. Computes per-class Dice for each volume (matching McClure et al. 2019, Section 2.4.1, Eq. 19), then averages across classes per volume. The reported "class Dice" in Table 3 of the paper is the mean ± std of these per-volume average Dice scores. For Bayesian models, MC inference produces variance and entropy maps (Eq. 20) saved as NIfTI files. 
Usage: python 04_evaluate.py --model checkpoints/bvwn_multi_prior \ --manifest manifest.csv --split test --n-samples 10 """ from __future__ import annotations import argparse import csv from pathlib import Path import matplotlib.pyplot as plt import nibabel as nib import numpy as np import torch from utils import load_config, save_figure, setup_logging log = setup_logging(__name__) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Evaluate segmentation model") parser.add_argument("--model", type=str, required=True) parser.add_argument("--manifest", type=str, required=True) parser.add_argument("--config", type=str, default="config.yaml") parser.add_argument("--split", type=str, default="test") parser.add_argument("--n-samples", type=int, default=10) parser.add_argument("--output-dir", type=str, default="results") return parser.parse_args() def load_manifest(manifest_path: str, split: str) -> list[tuple[str, str]]: pairs = [] with open(manifest_path) as f: for row in csv.DictReader(f): if row["split"] == split: pairs.append((row["t1w_path"], row["label_path"])) return pairs def per_class_dice( pred: np.ndarray, gt: np.ndarray, n_classes: int, ) -> np.ndarray: """Compute Dice coefficient for each class c = 1..n_classes-1. Matches Eq. 19 in McClure et al. (2019): Dice_c = 2*TP_c / (2*TP_c + FN_c + FP_c) Class 0 (background / unknown) is excluded, matching the paper: "averaging across all output voxels not classified as background". Parameters ---------- pred : np.ndarray Integer label predictions. gt : np.ndarray Integer ground truth labels. n_classes : int Total number of classes (including background). Returns ------- np.ndarray Shape ``(n_classes - 1,)`` — Dice for classes 1..n_classes-1. 
""" dice_scores = np.zeros(n_classes - 1) for c in range(1, n_classes): pred_c = (pred == c).astype(np.float64) gt_c = (gt == c).astype(np.float64) intersection = (pred_c * gt_c).sum() total = pred_c.sum() + gt_c.sum() if total > 0: dice_scores[c - 1] = 2.0 * intersection / total else: # Both empty for this class — perfect agreement dice_scores[c - 1] = 1.0 return dice_scores def compute_entropy(prob_map: np.ndarray) -> np.ndarray: """Compute entropy of softmax probabilities (Eq. 20). H(y|x) = -sum_c p(y_c|x) log p(y_c|x) """ eps = 1e-10 return -(prob_map * np.log(prob_map + eps)).sum(axis=0) def plot_prediction_overlay( t1w_arr: np.ndarray, pred_arr: np.ndarray, gt_arr: np.ndarray, output_path: Path, title: str = "Prediction Overlay", ) -> None: """3-panel figure: T1w, prediction, ground truth (middle axial slice).""" mid = t1w_arr.shape[2] // 2 t1 = t1w_arr[:, :, mid] pred = pred_arr[:, :, mid] gt = gt_arr[:, :, mid] fig, axes = plt.subplots(1, 3, figsize=(15, 5)) axes[0].imshow(t1.T, cmap="gray", origin="lower") axes[0].set_title("T1w Input") axes[0].axis("off") axes[1].imshow(t1.T, cmap="gray", origin="lower") axes[1].imshow( pred.T, cmap="nipy_spectral", alpha=0.4, origin="lower", vmin=0, vmax=max(pred.max(), 1), ) axes[1].set_title("Prediction") axes[1].axis("off") axes[2].imshow(t1.T, cmap="gray", origin="lower") axes[2].imshow( gt.T, cmap="nipy_spectral", alpha=0.4, origin="lower", vmin=0, vmax=max(gt.max(), 1), ) axes[2].set_title("Ground Truth") axes[2].axis("off") fig.suptitle(title) fig.tight_layout() save_figure(fig, output_path) plt.close(fig) def plot_per_class_dice( class_dice_all: np.ndarray, class_names: list[str] | None, output_path: Path, ) -> None: """Bar chart of mean per-class Dice across all volumes.""" mean_dice = class_dice_all.mean(axis=0) std_dice = class_dice_all.std(axis=0) n = len(mean_dice) fig, ax = plt.subplots(figsize=(max(12, n * 0.3), 6)) x = np.arange(n) ax.bar(x, mean_dice, yerr=std_dice, capsize=2, alpha=0.7, 
color="steelblue") ax.set_xlabel("Class") ax.set_ylabel("Dice") ax.set_title("Per-Class Dice (mean ± std across volumes)") ax.set_ylim(0, 1.05) if class_names and len(class_names) == n: ax.set_xticks(x) ax.set_xticklabels(class_names, rotation=90, fontsize=6) fig.tight_layout() save_figure(fig, output_path) plt.close(fig) def main() -> None: args = parse_args() output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) fig_dir = output_dir / "figures" fig_dir.mkdir(parents=True, exist_ok=True) # ---- Load config for n_classes and label_mapping ------------------------ config = load_config(args.config) n_classes = config.get("n_classes", 2) label_mapping = config.get("label_mapping", "binary") # Load label names for plots class_names = None if label_mapping and label_mapping != "binary": mapping_path = ( Path(__file__).parent / "label_mappings" / f"{label_mapping}-mapping.csv" ) if mapping_path.exists(): with open(mapping_path) as f: reader = csv.DictReader(f) rows = list(reader) # Build class_names indexed by 'new' column (skip background=0) name_map = {} for r in rows: new_id = int(r["new"]) if new_id > 0 and new_id not in name_map: name_map[new_id] = r.get("label", str(new_id)) class_names = [name_map.get(i, str(i)) for i in range(1, n_classes)] # Load remap function for ground truth remap_fn = None if label_mapping and label_mapping != "binary": from nobrainer.processing.dataset import _load_label_mapping remap_fn = _load_label_mapping(label_mapping) # ---- Load model --------------------------------------------------------- from nobrainer.processing.segmentation import Segmentation log.info("Loading model from %s", args.model) seg = Segmentation.load(args.model) block_shape = seg.block_shape_ or tuple(config["block_shape"]) log.info( "Model: %s, block_shape=%s, n_classes=%s", seg.base_model, block_shape, n_classes, ) # ---- Load manifest ------------------------------------------------------ pairs = load_manifest(args.manifest, 
split=args.split) log.info("Evaluating %d volumes from split '%s'", len(pairs), args.split) if not pairs: log.error("No volumes for split '%s'", args.split) return # ---- Evaluate each volume ----------------------------------------------- results: list[dict] = [] all_class_dice: list[np.ndarray] = [] n_samples = args.n_samples for idx, (img_path, lbl_path) in enumerate(pairs): vol_name = Path(img_path).stem log.info("Volume %d/%d: %s", idx + 1, len(pairs), vol_name) # Load and remap ground truth gt_arr = np.asarray(nib.load(lbl_path).dataobj, dtype=np.int32) if remap_fn is not None: gt_tensor = torch.from_numpy(gt_arr) gt_arr = remap_fn(gt_tensor).numpy().astype(np.int32) elif label_mapping == "binary": gt_arr = (gt_arr > 0).astype(np.int32) t1w_arr = np.asarray(nib.load(img_path).dataobj, dtype=np.float32) # Predict (batch_size=128 to utilize GPU memory with 32³ blocks) if n_samples > 0: pred_result = seg.predict( img_path, block_shape=block_shape, n_samples=n_samples, batch_size=128, ) if isinstance(pred_result, tuple): label_img, var_img, entropy_img = pred_result nib.save(var_img, str(output_dir / f"{vol_name}_variance.nii.gz")) nib.save(entropy_img, str(output_dir / f"{vol_name}_entropy.nii.gz")) else: label_img = pred_result else: label_img = seg.predict(img_path, block_shape=block_shape, batch_size=128) pred_arr = np.asarray(label_img.dataobj, dtype=np.int32) # Per-class Dice (Eq. 
19) class_dice = per_class_dice(pred_arr, gt_arr, n_classes) avg_dice = float(class_dice.mean()) all_class_dice.append(class_dice) log.info( " Avg class Dice = %.4f (min=%.4f, max=%.4f)", avg_dice, class_dice.min(), class_dice.max(), ) results.append( { "volume": vol_name, "image_path": img_path, "avg_class_dice": avg_dice, "min_class_dice": float(class_dice.min()), "max_class_dice": float(class_dice.max()), } ) # Overlay figure plot_prediction_overlay( t1w_arr, pred_arr.astype(np.float32), gt_arr.astype(np.float32), fig_dir / f"{vol_name}_overlay.png", title=f"{vol_name} — Avg Dice={avg_dice:.4f}", ) # ---- Per-class Dice bar chart ------------------------------------------- class_dice_matrix = np.array(all_class_dice) # (n_volumes, n_classes-1) plot_per_class_dice(class_dice_matrix, class_names, fig_dir / "per_class_dice.png") # ---- Save CSV with per-volume results ----------------------------------- csv_path = output_dir / "dice_scores.csv" with open(csv_path, "w", newline="") as f: writer = csv.DictWriter( f, fieldnames=[ "volume", "image_path", "avg_class_dice", "min_class_dice", "max_class_dice", ], ) writer.writeheader() writer.writerows(results) # ---- Save per-class Dice matrix ----------------------------------------- np.save(output_dir / "per_class_dice.npy", class_dice_matrix) # ---- Summary (matching Table 3 format) ---------------------------------- avg_dices = [r["avg_class_dice"] for r in results] log.info("=" * 60) log.info("Evaluation Summary (%s split, %d-class)", args.split, n_classes) log.info(" Volumes : %d", len(avg_dices)) log.info(" MC samples : %d", n_samples) log.info(" Class Dice : %.4f ± %.4f", np.mean(avg_dices), np.std(avg_dices)) log.info(" Min volume Dice : %.4f", np.min(avg_dices)) log.info(" Max volume Dice : %.4f", np.max(avg_dices)) # Per-class summary: median and range across volumes mean_per_class = class_dice_matrix.mean(axis=0) # (n_classes-1,) log.info( " Per-class Dice : median=%.4f, range=[%.4f, %.4f]", 
np.median(mean_per_class), mean_per_class.min(), mean_per_class.max(), ) if class_names: worst_5 = np.argsort(mean_per_class)[:5] best_5 = np.argsort(mean_per_class)[-5:][::-1] log.info( " Worst 5 classes : %s", ", ".join(f"{class_names[i]}={mean_per_class[i]:.3f}" for i in worst_5), ) log.info( " Best 5 classes : %s", ", ".join(f"{class_names[i]}={mean_per_class[i]:.3f}" for i in best_5), ) # Save per-class summary CSV per_class_csv = output_dir / "per_class_dice_summary.csv" with open(per_class_csv, "w", newline="") as f: writer = csv.writer(f) writer.writerow( [ "class_id", "class_name", "mean_dice", "median_dice", "min_dice", "max_dice", ] ) for i in range(len(mean_per_class)): name = class_names[i] if class_names else str(i + 1) col = class_dice_matrix[:, i] writer.writerow( [ i + 1, name, f"{col.mean():.4f}", f"{np.median(col):.4f}", f"{col.min():.4f}", f"{col.max():.4f}", ] ) log.info(" Per-class summary : %s", per_class_csv) log.info(" Output : %s", output_dir) log.info("=" * 60) if __name__ == "__main__": main() ================================================ FILE: scripts/kwyk_reproduction/05_compare_kwyk.py ================================================ #!/usr/bin/env python """Compare new model predictions against original kwyk container. 
Usage: python 05_compare_kwyk.py \ --new-model checkpoints/bayesian \ --kwyk-dir /path/to/kwyk \ --manifest manifest.csv \ --split test \ --output-dir results/comparison """ from __future__ import annotations import argparse import csv from pathlib import Path import subprocess import matplotlib.pyplot as plt import nibabel as nib import numpy as np from utils import compute_dice, save_figure, setup_logging log = setup_logging(__name__) # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def parse_args() -> argparse.Namespace: """Parse command-line arguments.""" parser = argparse.ArgumentParser( description="Compare new model vs original kwyk predictions", ) parser.add_argument( "--new-model", type=str, required=True, help="Path to new model directory (model.pth + croissant.json)", ) parser.add_argument( "--kwyk-dir", type=str, required=True, help="Path to original kwyk repository (containing kwyk/cli.py)", ) parser.add_argument( "--manifest", type=str, required=True, help="Path to the dataset manifest CSV", ) parser.add_argument( "--split", type=str, default="test", help="Which split to evaluate on (default: test)", ) parser.add_argument( "--output-dir", type=str, default="results/comparison", help="Directory for comparison outputs", ) return parser.parse_args() def load_manifest(manifest_path: str, split: str) -> list[tuple[str, str]]: """Load manifest CSV and return (image, label) pairs for the given split.""" pairs = [] with open(manifest_path) as f: reader = csv.DictReader(f) for row in reader: if row["split"] == split: pairs.append((row["t1w_path"], row["label_path"])) return pairs # --------------------------------------------------------------------------- # Run original kwyk prediction # --------------------------------------------------------------------------- def run_kwyk_prediction( kwyk_dir: str, infile: str, outdir: Path, ) -> Path | None: 
"""Run original kwyk CLI to produce a prediction. Calls:: python {kwyk_dir}/kwyk/cli.py predict \\ -m bwn_multi -n 1 {infile} {outprefix} Returns the path to the prediction NIfTI, or None on failure. """ vol_stem = Path(infile).stem.replace(".nii", "") outprefix = str(outdir / f"kwyk_{vol_stem}") cmd = [ "python", str(Path(kwyk_dir) / "kwyk" / "cli.py"), "predict", "-m", "bwn_multi", "-n", "1", infile, outprefix, ] log.info("Running kwyk: %s", " ".join(cmd)) try: result = subprocess.run( cmd, capture_output=True, text=True, timeout=600, check=False, ) if result.returncode != 0: log.error("kwyk failed (rc=%d): %s", result.returncode, result.stderr) return None except subprocess.TimeoutExpired: log.error("kwyk timed out for %s", infile) return None except FileNotFoundError: log.error("kwyk CLI not found at %s", cmd[1]) return None # kwyk outputs {outprefix}_means.nii.gz or {outprefix}.nii.gz for suffix in ["_means.nii.gz", ".nii.gz", "_prediction.nii.gz"]: candidate = Path(outprefix + suffix) if candidate.exists(): return candidate log.warning("Could not find kwyk output for prefix %s", outprefix) return None # --------------------------------------------------------------------------- # Spatial correlation between uncertainty maps # --------------------------------------------------------------------------- def compute_spatial_correlation(map1: np.ndarray, map2: np.ndarray) -> float: """Compute Pearson correlation between two spatial maps. Parameters ---------- map1, map2 : np.ndarray Flattened or volumetric arrays of the same shape. Returns ------- float Pearson correlation coefficient, or 0.0 on failure. 
""" v1 = map1.flatten().astype(np.float64) v2 = map2.flatten().astype(np.float64) # Remove positions where both are zero mask = (v1 != 0) | (v2 != 0) if mask.sum() < 2: return 0.0 v1 = v1[mask] v2 = v2[mask] std1 = np.std(v1) std2 = np.std(v2) if std1 == 0 or std2 == 0: return 0.0 return float(np.corrcoef(v1, v2)[0, 1]) # --------------------------------------------------------------------------- # Scatter plot # --------------------------------------------------------------------------- def plot_dice_scatter( kwyk_dices: list[float], new_dices: list[float], volume_names: list[str], output_path: Path, ) -> None: """Generate scatter plot: kwyk Dice (x) vs new model Dice (y).""" fig, ax = plt.subplots(figsize=(8, 8)) ax.scatter(kwyk_dices, new_dices, alpha=0.7, edgecolors="k", s=50) # Identity line lims = [0.0, 1.0] ax.plot(lims, lims, "k--", alpha=0.3, label="y = x") ax.set_xlabel("Original kwyk Dice") ax.set_ylabel("New Model Dice") ax.set_title("Dice Comparison: Original kwyk vs New Model") ax.set_xlim(lims) ax.set_ylim(lims) ax.set_aspect("equal") ax.legend() ax.grid(True, alpha=0.3) fig.tight_layout() save_figure(fig, output_path) plt.close(fig) log.info("Scatter plot saved to %s", output_path) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: """Compare new model vs original kwyk on test volumes.""" args = parse_args() output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) kwyk_pred_dir = output_dir / "kwyk_predictions" kwyk_pred_dir.mkdir(parents=True, exist_ok=True) # ---- Load new model ----------------------------------------------------- from nobrainer.processing.segmentation import Segmentation log.info("Loading new model from %s", args.new_model) seg = Segmentation.load(args.new_model) block_shape = seg.block_shape_ or (32, 32, 32) # ---- Load manifest 
------------------------------------------------------ pairs = load_manifest(args.manifest, split=args.split) log.info( "Comparing on %d volumes from split '%s'", len(pairs), args.split, ) if not pairs: log.error("No volumes found for split '%s'. Exiting.", args.split) return # ---- Evaluate each volume ----------------------------------------------- results: list[dict] = [] kwyk_dices: list[float] = [] new_dices: list[float] = [] volume_names: list[str] = [] for idx, (img_path, lbl_path) in enumerate(pairs): vol_name = Path(img_path).stem log.info("Volume %d/%d: %s", idx + 1, len(pairs), vol_name) # Load ground truth and binarize gt_arr = np.asarray(nib.load(lbl_path).dataobj, dtype=np.float32) gt_binary = (gt_arr > 0).astype(np.float32) # ---- New model prediction ------------------------------------------- label_img = seg.predict(img_path, block_shape=block_shape) new_pred = np.asarray(label_img.dataobj, dtype=np.float32) new_binary = (new_pred > 0).astype(np.float32) new_dice = compute_dice(new_binary, gt_binary) log.info(" New model Dice = %.4f", new_dice) # ---- Original kwyk prediction --------------------------------------- kwyk_pred_path = run_kwyk_prediction(args.kwyk_dir, img_path, kwyk_pred_dir) kwyk_dice = float("nan") if kwyk_pred_path is not None and kwyk_pred_path.exists(): kwyk_arr = np.asarray( nib.load(str(kwyk_pred_path)).dataobj, dtype=np.float32 ) kwyk_binary = (kwyk_arr > 0).astype(np.float32) kwyk_dice = compute_dice(kwyk_binary, gt_binary) log.info(" kwyk Dice = %.4f", kwyk_dice) else: log.warning(" kwyk prediction not available for %s", vol_name) results.append( { "volume": vol_name, "new_dice": new_dice, "kwyk_dice": kwyk_dice, "image_path": img_path, } ) new_dices.append(new_dice) kwyk_dices.append(kwyk_dice) volume_names.append(vol_name) # ---- Save comparison CSV ------------------------------------------------ csv_path = output_dir / "comparison_table.csv" with open(csv_path, "w", newline="") as f: writer = csv.DictWriter( f, 
fieldnames=["volume", "new_dice", "kwyk_dice", "image_path"], ) writer.writeheader() writer.writerows(results) log.info("Comparison table saved to %s", csv_path) # ---- Scatter plot ------------------------------------------------------- # Filter out NaN kwyk dices for plotting valid_mask = [not np.isnan(kd) for kd in kwyk_dices] valid_kwyk = [kd for kd, v in zip(kwyk_dices, valid_mask) if v] valid_new = [nd for nd, v in zip(new_dices, valid_mask) if v] valid_names = [n for n, v in zip(volume_names, valid_mask) if v] if valid_kwyk: scatter_path = output_dir / "dice_scatter.png" plot_dice_scatter(valid_kwyk, valid_new, valid_names, scatter_path) else: log.warning("No valid kwyk predictions; skipping scatter plot") # ---- Spatial correlation of uncertainty maps ---------------------------- # Check if both models have uncertainty outputs new_results_dir = Path(args.new_model).parent / "results" if new_results_dir.exists(): log.info("Checking for uncertainty map correlations...") for vol_name in volume_names: new_var_path = new_results_dir / f"{vol_name}_variance.nii.gz" kwyk_var_candidates = list(kwyk_pred_dir.glob(f"kwyk_{vol_name}*variance*")) if new_var_path.exists() and kwyk_var_candidates: new_var = np.asarray( nib.load(str(new_var_path)).dataobj, dtype=np.float32 ) kwyk_var = np.asarray( nib.load(str(kwyk_var_candidates[0])).dataobj, dtype=np.float32, ) corr = compute_spatial_correlation(new_var, kwyk_var) log.info( " %s uncertainty correlation: %.4f", vol_name, corr, ) # ---- Summary ------------------------------------------------------------ log.info("=" * 60) log.info("Comparison Summary (%s split)", args.split) log.info(" Volumes compared : %d", len(results)) if valid_kwyk: log.info(" Mean kwyk Dice : %.4f", np.nanmean(kwyk_dices)) log.info(" Mean new Dice : %.4f", np.mean(new_dices)) if valid_kwyk: improvement = np.mean(valid_new) - np.mean(valid_kwyk) log.info(" Mean improvement : %+.4f", improvement) log.info(" Output directory : %s", output_dir) 
log.info("=" * 60) if __name__ == "__main__": main() ================================================ FILE: scripts/kwyk_reproduction/06_block_size_sweep.py ================================================ #!/usr/bin/env python """Sweep over block sizes to compare segmentation performance. Usage: python 06_block_size_sweep.py --manifest manifest.csv --config config.yaml python 06_block_size_sweep.py --manifest manifest.csv --config config.yaml \ --block-sizes 32 64 128 --epochs 20 --output-dir results/sweep """ from __future__ import annotations import argparse import csv from pathlib import Path import time import matplotlib.pyplot as plt import numpy as np import torch from utils import compute_dice, load_config, save_figure, setup_logging log = setup_logging(__name__) # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def parse_args() -> argparse.Namespace: """Parse command-line arguments.""" parser = argparse.ArgumentParser( description="Block size sweep for Bayesian MeshNet segmentation", ) parser.add_argument( "--manifest", type=str, required=True, help="Path to the dataset manifest CSV", ) parser.add_argument( "--config", type=str, default="config.yaml", help="Path to YAML configuration file", ) parser.add_argument( "--block-sizes", type=int, nargs="+", default=[32, 64, 128], help="Block sizes to sweep over (default: 32 64 128)", ) parser.add_argument( "--epochs", type=int, default=20, help="Number of training epochs per block size (default: 20)", ) parser.add_argument( "--output-dir", type=str, default="results/sweep", help="Directory for sweep outputs", ) return parser.parse_args() def load_manifest(manifest_path: str, split: str) -> list[tuple[str, str]]: """Load manifest CSV and return (image, label) pairs for the given split.""" pairs = [] with open(manifest_path) as f: reader = csv.DictReader(f) for row in reader: if row["split"] == split: 
pairs.append((row["t1w_path"], row["label_path"])) return pairs # --------------------------------------------------------------------------- # Train + evaluate for one block size # --------------------------------------------------------------------------- def train_and_evaluate( block_size: int, config: dict, train_pairs: list[tuple[str, str]], val_pairs: list[tuple[str, str]], epochs: int, ) -> dict: """Train a Bayesian MeshNet at the given block size and evaluate Dice. Returns ------- dict Keys: block_size, mean_dice, std_dice, per_volume_dices, final_loss. """ import nibabel as nib from nobrainer.models import get as get_model from nobrainer.models.bayesian.utils import accumulate_kl from nobrainer.prediction import predict from nobrainer.processing.dataset import Dataset block_shape = (block_size, block_size, block_size) n_classes = config["n_classes"] batch_size = config["batch_size"] lr = config.get("lr", 1e-4) kl_weight = config.get("kl_weight", 1.0) log.info( "Training with block_size=%d for %d epochs...", block_size, epochs, ) # Build dataset ds_train = ( Dataset.from_files(train_pairs, block_shape=block_shape, n_classes=n_classes) .batch(batch_size) .binarize(config.get("label_mapping", "binary")) ) train_loader = ds_train.dataloader # Build model model_args = { "n_classes": n_classes, "filters": config.get("filters", 96), "receptive_field": config.get("receptive_field", 37), "dropout_rate": config.get("dropout_rate", 0.25), } model = get_model("bayesian_meshnet")(**model_args) from nobrainer.gpu import get_device device = get_device() model = model.to(device) ce_loss = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=lr) # Training loop final_loss = 0.0 for epoch in range(epochs): model.train() epoch_loss = 0.0 n_batches = 0 for batch in train_loader: if isinstance(batch, dict): images = batch["image"].to(device) labels = batch["label"].to(device) elif isinstance(batch, (list, tuple)): images = batch[0].to(device) labels 
= batch[1].to(device) else: raise TypeError(f"Unsupported batch type: {type(batch)}") if labels.ndim == images.ndim and labels.shape[1] == 1: labels = labels.squeeze(1) if labels.dtype in (torch.float32, torch.float64): labels = labels.long() optimizer.zero_grad() pred = model(images) loss = ce_loss(pred, labels) + kl_weight * accumulate_kl(model) loss.backward() optimizer.step() epoch_loss += loss.item() n_batches += 1 final_loss = epoch_loss / max(n_batches, 1) if (epoch + 1) % 5 == 0 or epoch == epochs - 1: log.info( " block_size=%d, epoch %d/%d, loss=%.6f", block_size, epoch + 1, epochs, final_loss, ) # Evaluate on validation set model.eval() per_volume_dices: list[float] = [] for img_path, lbl_path in val_pairs: gt_arr = np.asarray(nib.load(lbl_path).dataobj, dtype=np.float32) gt_binary = (gt_arr > 0).astype(np.float32) pred_img = predict( inputs=img_path, model=model, block_shape=block_shape, batch_size=4, return_labels=True, ) pred_arr = np.asarray(pred_img.dataobj, dtype=np.float32) pred_binary = (pred_arr > 0).astype(np.float32) dice = compute_dice(pred_binary, gt_binary) per_volume_dices.append(dice) mean_dice = float(np.mean(per_volume_dices)) if per_volume_dices else 0.0 std_dice = float(np.std(per_volume_dices)) if per_volume_dices else 0.0 log.info( " block_size=%d: Dice=%.4f +/- %.4f (%d volumes)", block_size, mean_dice, std_dice, len(per_volume_dices), ) return { "block_size": block_size, "mean_dice": mean_dice, "std_dice": std_dice, "per_volume_dices": per_volume_dices, "final_loss": final_loss, } # --------------------------------------------------------------------------- # Bar chart # --------------------------------------------------------------------------- def plot_block_size_comparison( sweep_results: list[dict], output_path: Path, ) -> None: """Generate bar chart: block_size on x, Dice on y with error bars.""" block_sizes = [r["block_size"] for r in sweep_results] means = [r["mean_dice"] for r in sweep_results] stds = [r["std_dice"] for r 
in sweep_results] fig, ax = plt.subplots(figsize=(8, 6)) x = np.arange(len(block_sizes)) bars = ax.bar( x, means, yerr=stds, capsize=5, color="steelblue", edgecolor="black", alpha=0.8, ) ax.set_xlabel("Block Size") ax.set_ylabel("Dice Score") ax.set_title("Block Size Sweep — Bayesian MeshNet") ax.set_xticks(x) ax.set_xticklabels([str(bs) for bs in block_sizes]) ax.set_ylim(0.0, 1.0) ax.grid(axis="y", alpha=0.3) # Annotate bars with mean values for bar, mean in zip(bars, means): ax.text( bar.get_x() + bar.get_width() / 2.0, bar.get_height() + 0.02, f"{mean:.3f}", ha="center", va="bottom", fontsize=10, ) fig.tight_layout() save_figure(fig, output_path) plt.close(fig) log.info("Bar chart saved to %s", output_path) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: """Run block size sweep and generate comparison outputs.""" args = parse_args() t_start = time.time() output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) # ---- Load config -------------------------------------------------------- config = load_config(args.config) log.info("Config loaded from %s", args.config) log.info( "Block size sweep: sizes=%s, epochs=%d", args.block_sizes, args.epochs, ) # ---- Load manifest ------------------------------------------------------ train_pairs = load_manifest(args.manifest, split="train") val_pairs = load_manifest(args.manifest, split="val") log.info( "Manifest: %d train, %d val volumes", len(train_pairs), len(val_pairs), ) if not train_pairs: log.error("No training volumes found. 
Exiting.") return if not val_pairs: log.warning("No validation volumes found; Dice will be empty.") # ---- Run sweep ---------------------------------------------------------- sweep_results: list[dict] = [] for block_size in args.block_sizes: result = train_and_evaluate( block_size=block_size, config=config, train_pairs=train_pairs, val_pairs=val_pairs, epochs=args.epochs, ) sweep_results.append(result) # ---- Save comparison CSV ------------------------------------------------ csv_path = output_dir / "block_size_comparison.csv" with open(csv_path, "w", newline="") as f: writer = csv.DictWriter( f, fieldnames=["block_size", "mean_dice", "std_dice", "final_loss"], ) writer.writeheader() for r in sweep_results: writer.writerow( { "block_size": r["block_size"], "mean_dice": r["mean_dice"], "std_dice": r["std_dice"], "final_loss": r["final_loss"], } ) log.info("Comparison CSV saved to %s", csv_path) # ---- Bar chart ---------------------------------------------------------- chart_path = output_dir / "block_size_comparison.png" plot_block_size_comparison(sweep_results, chart_path) # ---- Summary ------------------------------------------------------------ elapsed = time.time() - t_start best = max(sweep_results, key=lambda r: r["mean_dice"]) log.info("=" * 60) log.info("Block Size Sweep Complete") log.info(" Block sizes tested: %s", args.block_sizes) log.info(" Epochs per size : %d", args.epochs) log.info( " Best block size : %d (Dice=%.4f)", best["block_size"], best["mean_dice"] ) for r in sweep_results: log.info( " block_size=%3d: Dice=%.4f +/- %.4f, loss=%.6f", r["block_size"], r["mean_dice"], r["std_dice"], r["final_loss"], ) log.info(" Output directory : %s", output_dir) log.info(" Elapsed time : %.1f s", elapsed) log.info("=" * 60) if __name__ == "__main__": main() ================================================ FILE: scripts/kwyk_reproduction/ARCHITECTURE.md ================================================ # KWYK Architecture Verification This document records 
how the original kwyk model architecture was verified against the paper, source code, and trained model weights. ## Paper Reference McClure P. et al., "Knowing What You Know in Brain Segmentation Using Bayesian Deep Neural Networks", Front. Neuroinform. 2019. https://doi.org/10.3389/fninf.2019.00067 ## Three Model Variants | Variant | kwyk ID | Conv Layer | Dropout | MC at inference | |---------|---------|-----------|---------|-----------------| | MAP | `bwn` (all_50_wn) | VWN | Bernoulli (fixed) | No | | MC Bernoulli Dropout (BD) | `bwn_multi` (all_50_bwn_09_multi) | VWN | Bernoulli (fixed) | Yes | | Spike-and-Slab Dropout (SSD) | `bvwn_multi_prior` (all_50_bvwn_multi_prior) | VWN | Concrete (learned) | Yes | ## Architecture: Variational Weight Normalization (VWN) Conv ### Verified from trained model variables Downloaded `neuronets/kwyk:latest-cpu` Docker container and inspected the SavedModel variables for the SSD model (`all_50_bvwn_multi_prior`): ``` layer_1/conv3d/v:0: [3, 3, 3, 1, 96] # raw weight for WN layer_1/conv3d/g:0: [1, 1, 1, 1, 96] # gain per filter layer_1/conv3d/kernel_a:0: [3, 3, 3, 1, 96] # sigma = |kernel_a| layer_1/conv3d/bias_m:0: [96] # bias mean layer_1/conv3d/bias_a:0: [96] # bias sigma = |bias_a| layer_1/concrete_dropout/p:0: [96] # per-filter dropout rate ``` This confirms **weight normalization** (`v`, `g`) is used, not the direct `μ` parameterization described in the paper's equations. ### Key finding: all 3 models are independently trained VWN models All 3 saved models have the **same layer structure** (`v`, `g`, `kernel_a`, `bias_m`, `bias_a`) — including the MAP model (`all_50_wn`). 
They are **not** weight-sharing variants; they were trained independently: | Model | Total variables | Extra per layer | Timestamp | |-------|----------------|----------------|-----------| | all_50_wn (MAP) | 41 | — | 1555341859 | | all_50_bwn_09_multi (BD) | 41 | — | 1555963478 | | all_50_bvwn_multi_prior (SSD) | 48 | `concrete_dropout/p` | 1556816070 | The MAP and BD models have identical parameterization (both have `kernel_a` for learned sigma). Apart from their independently trained weights, the only difference is whether MC sampling is enabled at inference time. The SSD model additionally has 7 `concrete_dropout/p` parameters (one per 96-filter hidden conv layer, matching the `[96]` shape above) for learned per-filter dropout rates. ### Verified from source code Commit `4dd379c` in `neuronets/kwyk` repo (Patrick McClure, 2019-02-28): **`nobrainer/models/vwn_conv.py`** — `_Conv.build()`: ```python self.v = self.add_variable(name='v', ...) self.g = self.add_variable(name='g', ...) self.v_norm = tf.nn.l2_normalize(self.v, [...]) self.kernel_m = tf.multiply(self.g, self.v_norm, name='kernel_m') self.kernel_a = self.add_variable(name='kernel_a', ...) self.kernel_sigma = tf.abs(self.kernel_a, name='kernel_sigma') ``` **`_Conv.call()`** — local reparameterization trick: ```python outputs_mean = self._convolution_op(inputs, self.kernel_m) outputs_var = self._convolution_op(tf.square(inputs), tf.square(self.kernel_sigma)) outputs_e = tf.random_normal(shape=tf.shape(self.g)) # MC path: output = outputs_mean + tf.sqrt(outputs_var + 1e-8) * outputs_e ``` **`nobrainer/models/bayesian_dropout.py`** defines: - `bernoulli_dropout()` — standard MC dropout (bwn/bwn_multi) - `concrete_dropout()` — learned per-filter rate (bvwn_multi_prior) - `gaussian_dropout()` — not used in final models ### Paper vs Implementation discrepancy The paper (Section 2.2.3.2) describes the mean weight as `μ_{f,t}` (Eq.
13), but the actual implementation uses weight normalization: - `kernel_m = g · v / ||v||` (Salimans & Kingma 2016) - This is a reparameterization of the mean that aids training stability - The sigma is the same in both: `σ_{f,t} = |kernel_a_{f,t}|` The paper's equations are in terms of the effective mean (`μ`), which is computed via WN but isn't stored directly as a parameter. ## KL Divergence (Eq. 16-18) The KL divergence has two terms: 1. **Bernoulli KL** per filter for concrete dropout (Eq. 17): `KL(q_p || p_prior) = p·log(p/p_prior) + (1-p)·log((1-p)/(1-p_prior))` Prior: `p_prior = 0.5` 2. **Gaussian KL** per weight (Eq. 18): `KL(N(μ,σ) || N(μ_prior, σ_prior)) = log(σ_prior/σ) + (σ² + (μ-μ_prior)²)/(2σ²_prior) - 1/2` Prior: `μ_prior = 0, σ_prior = 0.1` ## Network Architecture (Table 2) 8 layers: seven dilated 3×3×3 convolutions followed by a 1×1×1 logits layer: - Layers 1-3: dilation=1, 96 filters, ReLU - Layer 4: dilation=2 - Layer 5: dilation=4 - Layer 6: dilation=8 - Layer 7: dilation=1 - Layer 8 (logits): 1×1×1, 50 filters, Softmax Receptive field = 1 + 2·(1+1+1+2+4+8+1) = 37 voxels. ## Our Implementation `nobrainer.models.bayesian.vwn_layers.FFGConv3d`: - Parameters: `v`, `g` (weight normalization), `kernel_a` (sigma), `bias_m`, `bias_a` - Forward: local reparameterization trick matching the original - KL: Eq. 18 with `prior_mu=0, prior_sigma=0.1` `nobrainer.models.bayesian.vwn_layers.ConcreteDropout3d`: - Learned `p` per filter via concrete relaxation (Eq. 10) - KL: Eq.
17 with `prior_p=0.5` `nobrainer.models.bayesian.kwyk_meshnet.KWYKMeshNet`: - Registered as `"kwyk_meshnet"` in the model registry (no Pyro dependency) - `dropout_type="bernoulli"` for bwn/bwn_multi - `dropout_type="concrete"` for bvwn_multi_prior (SSD) - `mc=True/False` flag controls stochastic vs deterministic inference ## Training Details (from paper) - Optimizer: Adam, lr=1e-4 - Batch size: 32 (4 GPUs × 8) - Block shape: 32×32×32 - Data: 11,480 T1 sMRI volumes, 50-class FreeSurfer parcellation - MC samples at inference: 10 ================================================ FILE: scripts/kwyk_reproduction/README.md ================================================ # KWYK Brain Extraction Reproduction Reproduce the kwyk Bayesian brain segmentation study (McClure et al., Frontiers in Neuroinformatics 2019) using the refactored PyTorch nobrainer, starting with binary brain extraction. **Reference**: https://www.frontiersin.org/journals/neuroinformatics/articles/10.3389/fninf.2019.00067/full ## Current Status The reproduction pipeline is **code-complete and CI-verified** (smoke test + small-scale 20-epoch training on a T4 GPU with real OpenNeuro data). Full-scale reproduction with 50+ epochs and 100+ subjects has **not yet been run**. See "Next Steps for GPU Execution" below.
## Quick Setup ```bash # Option A: Use the orchestrator script (creates venv automatically) cd scripts/kwyk_reproduction ./run.sh --smoke-test # Quick verification (5 volumes, 2 epochs) ./run.sh # Full pipeline # Option B: Manual setup uv venv --python 3.14 && source .venv/bin/activate uv pip install -e "../../[bayesian,versioning,dev]" monai pyro-ppl datalad matplotlib pyyaml scipy uv tool install git-annex # required for DataLad content retrieval ``` ## Programmatic API The dataset fetching is also available as a library: ```python from nobrainer.datasets.openneuro import ( install_derivatives, find_subject_pairs, write_manifest, ) # Clone fmriprep derivatives (metadata only, fast) ds = install_derivatives("ds000114", "/tmp/data") # Discover + download T1w + aparc+aseg pairs per subject pairs = find_subject_pairs(ds) # Write manifest CSV with train/val/test split write_manifest(pairs, "manifest.csv") ``` ## Pipeline Steps ### Step 1: Assemble Dataset ```bash python 01_assemble_dataset.py --datasets ds000114 --output-csv manifest.csv ``` Downloads T1w + aparc+aseg volumes from OpenNeuro fmriprep derivatives via DataLad. Start with 1 dataset (~10 subjects) for smoke testing, then scale: ```bash # Scale to more datasets python 01_assemble_dataset.py \ --datasets ds000114 ds000228 ds002609 ds001021 ds002105 \ --output-csv manifest.csv --conform ``` ### Step 2: Train Deterministic MeshNet (Warm-Start Foundation) ```bash python 02_train_meshnet.py --manifest manifest.csv --epochs 50 ``` Trains a standard MeshNet with kwyk-matching parameters (filters=96, block_shape=32³, lr=0.0001). This model's weights are used to warm-start the weight means of the Bayesian models in Step 3. **Output**: `checkpoints/meshnet/model.pth`, `figures/meshnet_learning_curve.png` ### Step 3: Train All Model Variants All 3 kwyk models use **Variational Weight Normalization (VWN)** convolutions with per-weight learned sigma and the local reparameterization trick. They differ in the dropout layer and in whether MC sampling is used at inference (see the table below).
See [ARCHITECTURE.md](ARCHITECTURE.md) for the full verification against the paper, code, and trained weights. Use `--variant` to select: ```bash # 3a. MC Bernoulli dropout (bwn_multi) — VWN conv + dropout at inference python 03_train_bayesian.py \ --manifest manifest.csv --variant bwn_multi \ --warmstart checkpoints/meshnet --output-dir checkpoints/bwn_multi \ --epochs 50 # 3b. Spike-and-slab dropout (bvwn_multi_prior) — VWN conv + concrete dropout python 03_train_bayesian.py \ --manifest manifest.csv --variant bvwn_multi_prior \ --warmstart checkpoints/meshnet --output-dir checkpoints/bvwn_multi_prior \ --epochs 50 ``` | Variant | kwyk ID | Conv | Dropout | MC at inference | |---------|---------|------|---------|-----------------| | `bwn` (step 2) | all_50_wn | VWN | Bernoulli (fixed) | No (MAP) | | `bwn_multi` | all_50_bwn_09_multi | VWN | Bernoulli (fixed) | Yes | | `bvwn_multi_prior` | all_50_bvwn_multi_prior | VWN | Concrete (learned) | Yes | The warm-start decomposes deterministic Conv3d weights into weight normalization form (`v`, `g`) for the VWN layers. **Output**: `checkpoints/<variant>/model.pth`, `checkpoints/<variant>/croissant.json`, `checkpoints/<variant>/learning_curve.png` ### Step 4: Evaluate ```bash # Evaluate each variant for variant in meshnet bwn_multi bvwn_multi_prior bayesian_gaussian; do python 04_evaluate.py \ --model checkpoints/$variant/model.pth \ --manifest manifest.csv --split test --n-samples 10 \ --output-dir results/$variant done ``` Computes per-volume and per-class Dice and, when MC sampling is enabled (`--n-samples` > 0), saves variance + entropy maps as NIfTI. ### Step 5: Compare with Original KWYK ```bash python 05_compare_kwyk.py \ --new-model checkpoints/bvwn_multi_prior/model.pth \ --kwyk-dir ../../kwyk \ --manifest manifest.csv ``` Runs the original kwyk CLI (`kwyk/cli.py predict -m bwn_multi` from the checkout given by `--kwyk-dir`) on the same test volumes and generates a Dice scatter plot + comparison table. **Note**: This requires a working kwyk checkout at `../../kwyk` whose `kwyk/cli.py` can be run with `python`.
The comparison is only meaningful after the Bayesian model has been trained to convergence (Steps 2-3). ### Step 6: Block Size Sweep (Optional) ```bash python 06_block_size_sweep.py --manifest manifest.csv --block-sizes 32 64 128 ``` ## Next Steps for GPU Execution The following steps should be performed on a machine with a GPU (e.g., the EC2 GPU runner or a local workstation): ### Phase 1: Smoke Test (15 minutes, any GPU) ```bash ./run.sh --smoke-test ``` Verify the pipeline works end-to-end with tiny models. Check `figures/` for learning curves showing loss decrease. ### Phase 2: Small-Scale Training (1-2 hours, T4 16GB) ```bash python 01_assemble_dataset.py --datasets ds000114 --output-csv manifest.csv python 02_train_meshnet.py --manifest manifest.csv --epochs 20 # Train all 3 Bayesian variants for variant in bwn_multi bvwn_multi_prior bayesian_gaussian; do python 03_train_bayesian.py --manifest manifest.csv \ --variant $variant --warmstart checkpoints/meshnet \ --output-dir checkpoints/$variant --epochs 20 done ``` **Expected**: Validation Dice ≥0.80 for brain extraction on 10 subjects. ### Phase 3: Full Reproduction (8-24 hours, V100 16GB+) ```bash python 01_assemble_dataset.py \ --datasets ds000114 ds000228 ds002609 ds001021 ds002105 \ --output-csv manifest.csv --conform python 02_train_meshnet.py --manifest manifest.csv --epochs 50 for variant in bwn_multi bvwn_multi_prior bayesian_gaussian; do python 03_train_bayesian.py --manifest manifest.csv \ --variant $variant --warmstart checkpoints/meshnet \ --output-dir checkpoints/$variant --epochs 50 done python 05_compare_kwyk.py \ --new-model checkpoints/bvwn_multi_prior/model.pth \ --kwyk-dir ../../kwyk --manifest manifest.csv ``` **Target**: Validation Dice ≥0.90 (kwyk achieved 0.97+ with 11,000 subjects). ### Phase 4: Scale and Optimize To approach kwyk's full performance: 1. **Add more datasets**: Add OpenNeuro dataset IDs to the `--datasets` list 2. 
**Block size sweep**: `python 06_block_size_sweep.py --manifest manifest.csv --block-sizes 32 64 128` 3. **SynthSeg augmentation**: `python 03_train_bayesian.py --augmentation mixed` 4. **Longer training**: Increase `--epochs` to 100+ ### Phase 5: Automated Hyperparameter Optimization Use nobrainer's autoresearch loop to explore hyperparameters overnight: ```bash # Set up the research directory mkdir -p research/kwyk_bayesian cp checkpoints/bvwn_multi_prior/model.pth research/kwyk_bayesian/ cat > research/kwyk_bayesian/program.md << 'EOF' ## Exploration Targets - kl_weight: 1e-5, 1e-4, 1e-3, 1e-2, 1e-1 - dropout_rate: 0.0, 0.1, 0.25, 0.5 - filters: 71, 96, 128 - prior_type: standard_normal, laplace - block_shape: 32, 64 - learning_rate: 1e-5, 5e-5, 1e-4, 5e-4 ## Success Criterion - val_dice improvement over current best - Max 30 min per experiment EOF # Launch overnight optimization nobrainer research run \ --working-dir research/kwyk_bayesian \ --model-family bayesian_meshnet \ --max-experiments 20 \ --budget-hours 8 ``` The autoresearch loop will: 1. Propose hyperparameter changes (via LLM or random grid) 2. Train, evaluate, keep improvements, revert failures 3.
Save the best model with full Croissant-ML provenance Check results: `cat research/kwyk_bayesian/run_summary.md` ## Configuration Edit `config.yaml` to change default hyperparameters: | Parameter | Default | kwyk Original | Notes | |-----------|---------|---------------|-------| | filters | 96 | 96 | Feature maps per layer | | receptive_field | 37 | 37 | Dilation schedule [1,1,1,2,4,8,1] | | block_shape | [32,32,32] | [32,32,32] | Patch size for training | | lr | 0.0001 | 0.0001 | Adam learning rate | | kl_weight | 1.0 | implicit | KL divergence scaling | | dropout_rate | 0.25 | 0.25 | Spatial dropout | | prior_type | spike_and_slab | spike_and_slab | SSD: π·N(0,0.001) + (1-π)·N(0,1) | | spike_sigma | 0.001 | ~0 | Spike component σ | | slab_sigma | 1.0 | ~1 | Slab component σ | | prior_pi | 0.5 | 0.5 | Spike probability | | n_classes | 50 | 50 | 50-class FreeSurfer parcellation (`binary_preset` uses 2 for brain extraction) | | label_mapping | 50-class | N/A | Also supports binary/6/115-class | ## Label Mappings The `label_mappings/` directory contains CSVs that remap FreeSurfer aparc+aseg codes to target classes: - **binary**: Any non-zero → 1 (brain extraction) - **6-class**: Coarse parcellation (WM, cortex, ventricles, cerebellum, etc.)
- **50-class**: Matches original kwyk study - **115-class**: Fine-grained parcellation ## GPU Requirements | Task | Block Size | Filters | GPU Memory | Time | |------|-----------|---------|------------|------| | Smoke test | 16³ | 16 | ≥4 GB | ~5 min | | Small training | 32³ | 96 | ≥16 GB | ~2 hr | | Full reproduction | 32³ | 96 | ≥16 GB | ~24 hr | | Block sweep 64³ | 64³ | 96 | ≥16 GB | ~4 hr | | Full-brain 256³ | 256³ | 96 | ≥24 GB | N/A | ## What Has Been Verified - [x] Smoke test on EC2 T4 GPU (2 epochs, all 3 variants) - [x] Small-scale training on EC2 T4 GPU (20 epochs, 10 subjects from ds000114, all 3 variants) - [x] DataLad + git-annex data pipeline (OpenNeuro fmriprep derivatives) - [x] Spike-and-slab prior, MC dropout, and Gaussian Bayesian variants ## What Has NOT Been Done Yet - [ ] Full-scale reproduction with 50+ epochs and 100+ subjects - [ ] Comparison against kwyk container (requires converged model) - [ ] Block size sweep results - [ ] SynthSeg augmentation experiments - [ ] Autoresearch hyperparameter optimization ================================================ FILE: scripts/kwyk_reproduction/__init__.py ================================================ ================================================ FILE: scripts/kwyk_reproduction/build_kwyk_manifest.py ================================================ #!/usr/bin/env python """Build a manifest CSV from the original KWYK dataset (PAC brain volumes). 
The KWYK dataset contains paired files: - pac__orig.nii.gz (T1w image) - pac__aseg.nii.gz (FreeSurfer aparc+aseg label) Usage: python build_kwyk_manifest.py --data-dir ../data/SharedData/segmentation/freesurfer_asegs \ --output-csv kwyk_manifest.csv --n-subjects 100 """ from __future__ import annotations import argparse import csv from pathlib import Path import random def main(): parser = argparse.ArgumentParser(description="Build manifest from KWYK PAC dataset") parser.add_argument( "--data-dir", type=str, required=True, help="Directory containing pac_*_orig.nii.gz and pac_*_aseg.nii.gz files", ) parser.add_argument( "--output-csv", type=str, default="kwyk_manifest.csv", help="Output manifest CSV path", ) parser.add_argument( "--n-subjects", type=int, default=None, help="Number of subjects to include (default: all)", ) parser.add_argument( "--split", nargs=3, type=int, default=[80, 10, 10], help="Train/val/test split percentages (default: 80 10 10)", ) parser.add_argument( "--seed", type=int, default=42, help="Random seed for shuffling and split (default: 42)", ) args = parser.parse_args() data_dir = Path(args.data_dir).resolve() if not data_dir.is_dir(): raise SystemExit(f"Data directory not found: {data_dir}") # Find all paired subjects orig_files = sorted(data_dir.glob("pac_*_orig.nii.gz")) pairs = [] for orig in orig_files: # Extract subject ID: pac__orig.nii.gz -> stem = orig.name # pac_123_orig.nii.gz subj_id = stem.replace("pac_", "").replace("_orig.nii.gz", "") aseg = data_dir / f"pac_{subj_id}_aseg.nii.gz" if aseg.exists(): pairs.append((subj_id, str(orig), str(aseg))) print(f"Found {len(pairs)} paired subjects in {data_dir}") if not pairs: raise SystemExit("No paired (orig, aseg) files found.") # Shuffle and subsample random.seed(args.seed) random.shuffle(pairs) if args.n_subjects is not None: pairs = pairs[: args.n_subjects] print(f"Subsampled to {len(pairs)} subjects") # Split n = len(pairs) train_pct, val_pct, test_pct = args.split assert train_pct + 
val_pct + test_pct == 100 n_train = int(n * train_pct / 100) n_val = int(n * val_pct / 100) # rest goes to test splits = ["train"] * n_train + ["val"] * n_val + ["test"] * (n - n_train - n_val) # Write manifest output_csv = Path(args.output_csv) with open(output_csv, "w", newline="") as f: writer = csv.writer(f) writer.writerow(["subject_id", "dataset_id", "t1w_path", "label_path", "split"]) for (subj_id, orig_path, aseg_path), split in zip(pairs, splits): writer.writerow([f"pac_{subj_id}", "kwyk", orig_path, aseg_path, split]) # Summary from collections import Counter split_counts = Counter(splits) print(f"Manifest written to {output_csv}") print( f" train: {split_counts['train']}, " f"val: {split_counts['val']}, test: {split_counts['test']}" ) if __name__ == "__main__": main() ================================================ FILE: scripts/kwyk_reproduction/config.yaml ================================================ # KWYK Brain Segmentation Reproduction - Default Configuration # Based on: McClure et al., Frontiers in Neuroinformatics 2019 # # The original kwyk study trained 3 model variants, all using # Variational Weight Normalization (VWN) convolutions: # 1. bwn — VWN conv + Bernoulli dropout OFF at inference (MAP) # 2. bwn_multi — VWN conv + Bernoulli dropout ON at inference (MC) # 3. 
bvwn_multi_prior — VWN conv + Concrete dropout (learned rates) # --------------------------------------------------------------------------- # Shared architecture (all variants use identical VWN structure) # --------------------------------------------------------------------------- filters: 96 receptive_field: 37 # dilation schedule [1,1,1,2,4,8,1] dropout_rate: 0.25 sigma_init: 0.0001 # initial |kernel_a| for VWN weight sigma # Training block_shape: [32, 32, 32] lr: 0.0001 batch_size: 32 # per-GPU; auto-optimized when GPU available n_classes: 50 # 50-class FreeSurfer parcellation (matching paper) label_mapping: 50-class # uses label_mappings/50-class-mapping.csv # Class weighting — the original paper used unweighted CrossEntropyLoss. # Set to "auto" or "median_frequency" for experiments with class balancing. # Options: null (paper default), "auto" (inverse frequency), "median_frequency" class_weight_method: null # Loss function: "cross_entropy" (paper default) or "dice_ce" (Dice + weighted CE) loss: cross_entropy # Warm-start (deterministic → VWN) pretrain_epochs: 50 bayesian_epochs: 50 # Inference n_samples: 10 # MC inference samples # Data augmentation (profile: none, light, standard, heavy) augmentation_profile: standard # Zarr optimization (optional — convert NIfTI to Zarr for faster I/O) # zarr_store: null # path like "data/brain_store.zarr" # zarr_chunk_shape: [32, 32, 32] # Stride for patch extraction (null = random, or [sD, sH, sW] for grid) # stride: null # Data assembly datasets: - ds000114 - ds000228 - ds002609 split: [80, 10, 10] # train/val/test percentages # --------------------------------------------------------------------------- # Model variant configurations (matching original kwyk) # --------------------------------------------------------------------------- variants: # 1. 
bwn — VWN conv, Bernoulli dropout, MC OFF at inference (MAP) bwn: model: kwyk_meshnet dropout_type: bernoulli mc_at_inference: false description: "VWN MeshNet, Bernoulli dropout, deterministic inference (MAP)" # 2. bwn_multi — VWN conv, Bernoulli dropout, MC ON at inference bwn_multi: model: kwyk_meshnet dropout_type: bernoulli mc_at_inference: true description: "VWN MeshNet, MC Bernoulli dropout at inference" # 3. bvwn_multi_prior — VWN conv, Concrete dropout (learned per-filter rates) bvwn_multi_prior: model: kwyk_meshnet dropout_type: concrete concrete_temperature: 0.02 concrete_init_p: 0.9 mc_at_inference: true description: "VWN MeshNet, Concrete dropout (learned rates)" # --------------------------------------------------------------------------- # Presets for different scales # --------------------------------------------------------------------------- # Quick binary brain extraction (for initial testing) binary_preset: n_classes: 2 label_mapping: binary class_weight_method: null # Smoke test overrides smoke_test: filters: 16 block_shape: [16, 16, 16] pretrain_epochs: 1 bayesian_epochs: 1 batch_size: 2 n_samples: 3 n_classes: 2 label_mapping: binary class_weight_method: null ================================================ FILE: scripts/kwyk_reproduction/config_kwyk_smoke.yaml ================================================ # KWYK PAC Dataset Smoke Test — 50-class parcellation, 100 subjects # Based on: McClure et al., Frontiers in Neuroinformatics 2019 # --------------------------------------------------------------------------- # Shared architecture (all variants use identical VWN structure) # --------------------------------------------------------------------------- filters: 96 receptive_field: 37 # dilation schedule [1,1,1,2,4,8,1] dropout_rate: 0.25 sigma_init: 0.0001 # initial |kernel_a| for VWN weight sigma # Training block_shape: [32, 32, 32] lr: 0.0001 batch_size: 256 n_classes: 50 # 50-class FreeSurfer parcellation label_mapping: 50-class 
patches_per_volume: 50 # random patches per volume per epoch (GPU utilization) # Warm-start (deterministic → VWN) pretrain_epochs: 50 bayesian_epochs: 50 # Validation val_dice_freq: 5 # full-volume Dice every N epochs (block-level metrics every epoch) # Inference n_samples: 10 # MC inference samples # Zarr store (fast chunk-aligned I/O, created by slurm_convert_zarr.sbatch) zarr_store: /orcd/scratch/orcd/013/satra/kwyk_reproduction/data/kwyk_500.zarr # Data augmentation augmentation: real # real, synthetic, mixed # Data assembly datasets: - kwyk split: [80, 10, 10] # train/val/test percentages # --------------------------------------------------------------------------- # Model variant configurations (matching original kwyk) # --------------------------------------------------------------------------- variants: bwn: model: kwyk_meshnet dropout_type: bernoulli mc_at_inference: false description: "VWN MeshNet, Bernoulli dropout, deterministic inference (MAP)" bwn_multi: model: kwyk_meshnet dropout_type: bernoulli mc_at_inference: true description: "VWN MeshNet, MC Bernoulli dropout at inference" bvwn_multi_prior: model: kwyk_meshnet dropout_type: concrete concrete_temperature: 0.02 concrete_init_p: 0.9 mc_at_inference: true description: "VWN MeshNet, Concrete dropout (learned rates)" ================================================ FILE: scripts/kwyk_reproduction/convert_zarr_shard.py ================================================ #!/usr/bin/env python """Convert one shard of NIfTI volumes to a pre-created Zarr3 store. Usage: python convert_zarr_shard.py --manifest manifest.csv --zarr-store data/store.zarr \ --shard-idx 0 --subjects-per-shard 50 Called by SLURM job array — each task writes one shard independently. 
""" from __future__ import annotations import argparse import csv from pathlib import Path import time import nibabel as nib import numpy as np import zarr def main(): parser = argparse.ArgumentParser() parser.add_argument("--manifest", required=True) parser.add_argument("--zarr-store", required=True) parser.add_argument("--shard-idx", type=int, required=True) parser.add_argument("--subjects-per-shard", type=int, default=50) parser.add_argument( "--create", action="store_true", help="Create the store (only shard 0 should do this)", ) args = parser.parse_args() # Read manifest pairs = [] subject_ids = [] with open(args.manifest) as f: for row in csv.DictReader(f): pairs.append((row["t1w_path"], row["label_path"])) subject_ids.append(row["subject_id"]) n_subjects = len(pairs) sps = args.subjects_per_shard start = args.shard_idx * sps end = min(start + sps, n_subjects) if start >= n_subjects: print(f"Shard {args.shard_idx}: no subjects (start={start} >= {n_subjects})") return store_path = Path(args.zarr_store) if args.create: # Create the store and arrays (only one task does this) D, H, W = 256, 256, 256 n_shards = (n_subjects + sps - 1) // sps store = zarr.open_group(str(store_path), mode="w") store.create_array( "images", shape=(n_subjects, D, H, W), chunks=(1, 32, 32, 32), shards=(sps, D, H, W), dtype=np.float32, ) store.create_array( "labels", shape=(n_subjects, D, H, W), chunks=(1, 32, 32, 32), shards=(sps, D, H, W), dtype=np.int32, ) store.attrs["n_subjects"] = n_subjects store.attrs["subject_ids"] = subject_ids store.attrs["volume_shape"] = [D, H, W] print(f"Created store: {store_path} ({n_subjects} subjects, {n_shards} shards)") # Write partition JSON import json partitions = {"train": [], "val": [], "test": []} with open(args.manifest) as f: for row in csv.DictReader(f): partitions[row["split"]].append(row["subject_id"]) part_path = str(store_path) + "_partition.json" with open(part_path, "w") as f: json.dump({"partitions": partitions}, f, indent=2) for k, v 
in partitions.items(): print(f" {k}: {len(v)} subjects") else: # Open existing store in append mode store = zarr.open_group(str(store_path), mode="r+") images_arr = store["images"] labels_arr = store["labels"] t0 = time.time() for i in range(start, end): img_path, lbl_path = pairs[i] # PAC data is already 256³ @ 1mm uint8/int32 — no conform needed img_data = np.asarray(nib.load(img_path).dataobj, dtype=np.float32) lbl_data = np.asarray(nib.load(lbl_path).dataobj, dtype=np.int32) images_arr[i] = img_data[:256, :256, :256] labels_arr[i] = lbl_data[:256, :256, :256] if (i - start + 1) % 10 == 0: elapsed = time.time() - t0 rate = (i - start + 1) / elapsed print( f" Shard {args.shard_idx}: {i - start + 1}/{end - start} " f"({rate:.1f} vol/s, {elapsed:.0f}s)" ) elapsed = time.time() - t0 print( f"Shard {args.shard_idx}: wrote {end - start} volumes in {elapsed:.1f}s " f"({(end - start) / elapsed:.1f} vol/s)" ) if __name__ == "__main__": main() ================================================ FILE: scripts/kwyk_reproduction/experiments/01_20260330_eval_deterministic/README.md ================================================ # Experiment 01: Evaluate Bayesian models in deterministic mode ## Rationale All 3 Bayesian variants show zero Dice during MC evaluation despite training loss decreasing from ~3.8 to ~2.2. The prediction code calls `model(tensor)` which defaults to `mc=True` in KWYKMeshNet.forward(), activating local reparameterization noise and dropout. With only 20 epochs of Bayesian training, this noise may overwhelm the learned signal. **Hypothesis:** The model weights have learned meaningful representations, but MC inference noise destroys the output. Evaluating with `mc=False` should show non-zero Dice. ## Plan 1. Write a quick eval script that loads each Bayesian checkpoint and runs prediction with `mc=False` (deterministic forward pass) 2. Compare per-class Dice between mc=True and mc=False 3. 
No retraining needed — just evaluate existing checkpoints ## Tasks - [x] Write eval script with mc=False support - [x] Run on existing 20-epoch checkpoints - [ ] Compare results ================================================ FILE: scripts/kwyk_reproduction/experiments/01_20260330_eval_deterministic/eval_deterministic.py ================================================ #!/usr/bin/env python """Evaluate Bayesian models in deterministic mode (mc=False). Quick diagnostic: do the weights contain useful information that MC noise destroys? """ from __future__ import annotations import csv from pathlib import Path import sys import nibabel as nib import numpy as np import torch sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from utils import setup_logging # noqa: E402 log = setup_logging(__name__) def per_class_dice(pred: np.ndarray, gt: np.ndarray, n_classes: int) -> np.ndarray: """Per-class Dice for classes 1..n_classes-1.""" dice = np.zeros(n_classes - 1) for c in range(1, n_classes): p = pred == c g = gt == c inter = (p & g).sum() total = p.sum() + g.sum() dice[c - 1] = 2.0 * inter / total if total > 0 else 1.0 return dice def predict_volume(model, img_path, block_shape, mc=False): """Block-based prediction on a single volume.""" from nobrainer.prediction import _extract_blocks, _pad_to_multiple, _stitch_blocks from nobrainer.training import get_device device = get_device() img = nib.load(str(img_path)) arr = np.asarray(img.dataobj, dtype=np.float32) orig_shape = arr.shape[:3] padded, pad = _pad_to_multiple(arr, block_shape) blocks, grid = _extract_blocks(padded, block_shape) model = model.to(device) model.eval() all_preds = [] with torch.no_grad(): for start in range(0, len(blocks), 4): chunk = blocks[start : start + 4] tensor = torch.from_numpy(chunk[:, None]).to(device) out = model(tensor, mc=mc) labels = out.argmax(dim=1, keepdim=True).float() all_preds.append(labels.cpu().numpy()) block_preds = np.concatenate(all_preds, axis=0) full = 
_stitch_blocks(block_preds, grid, block_shape, pad, orig_shape, 1)[0] return full.astype(np.int32) def main(): from nobrainer.processing.dataset import _load_label_mapping from nobrainer.processing.segmentation import Segmentation work_dir = Path(__file__).parent.parent.parent manifest_path = work_dir / "kwyk_manifest.csv" remap_fn = _load_label_mapping("50-class") n_classes = 50 # Load test pairs pairs = [] with open(manifest_path) as f: for row in csv.DictReader(f): if row["split"] == "test": pairs.append((row["t1w_path"], row["label_path"])) log.info("Test volumes: %d", len(pairs)) # Evaluate each variant variants = [ "kwyk_smoke_bwn_multi", "kwyk_smoke_bvwn_multi_prior", "kwyk_smoke_bayesian_gaussian", ] results = [] for variant in variants: ckpt_dir = work_dir / "checkpoints" / variant if not (ckpt_dir / "model.pth").exists(): log.warning("Skipping %s — no checkpoint", variant) continue log.info("=== %s ===", variant) seg = Segmentation.load(ckpt_dir) model = seg.model_ block_shape = seg.block_shape_ or (32, 32, 32) for mc_mode in [False, True]: mode_name = "mc" if mc_mode else "deterministic" all_dice = [] for idx, (img_path, lbl_path) in enumerate(pairs[:3]): # first 3 for speed pred_arr = predict_volume(model, img_path, block_shape, mc=mc_mode) gt_arr = np.asarray(nib.load(lbl_path).dataobj, dtype=np.int32) gt_arr = remap_fn(torch.from_numpy(gt_arr)).numpy() cd = per_class_dice(pred_arr, gt_arr, n_classes) avg = float(cd.mean()) all_dice.append(avg) log.info( " [%s] vol %d: avg_dice=%.4f max=%.4f", mode_name, idx + 1, avg, cd.max(), ) mean_dice = float(np.mean(all_dice)) results.append( {"variant": variant, "mode": mode_name, "mean_dice": mean_dice} ) log.info(" [%s] MEAN: %.4f", mode_name, mean_dice) # Save results out_path = Path(__file__).parent / "results.csv" with open(out_path, "w", newline="") as f: w = csv.DictWriter(f, fieldnames=["variant", "mode", "mean_dice"]) w.writeheader() w.writerows(results) log.info("Results saved to %s", out_path) if 
__name__ == "__main__": main() ================================================ FILE: scripts/kwyk_reproduction/experiments/01_20260330_eval_deterministic/results_summary.md ================================================ # Experiment 01 Results ## Finding **Both deterministic and MC modes produce zero Dice for all 3 Bayesian variants.** | Variant | Deterministic | MC | |---|---|---| | bwn_multi | 0.0000 | 0.0000 | | bvwn_multi_prior | 0.0000 | 0.0001 | | bayesian_gaussian | 0.0000 | 0.0000 | ## Conclusion The hypothesis was wrong — the issue is NOT MC noise destroying learned signals. The weights themselves contain no useful information. Possible causes: 1. **Warm-start transfer failure**: weights not properly transferred from MeshNet to KWYKMeshNet 2. **Bayesian training destroying warm-start**: ELBO loss / weight perturbation undoing the transferred weights 3. **Architecture mismatch**: MeshNet vs KWYKMeshNet parameter shapes may not align ## Next Steps - Investigate warm-start transfer code - Check if Bayesian model immediately after warm-start (before training) can segment - Compare model architectures between MeshNet and KWYKMeshNet ================================================ FILE: scripts/kwyk_reproduction/experiments/01_20260330_eval_deterministic/run.sbatch ================================================ #!/bin/bash #SBATCH --job-name=exp01-det #SBATCH --partition=mit_preemptable #SBATCH --gres=gpu:1 #SBATCH --cpus-per-task=4 #SBATCH --mem=32G #SBATCH --time=00:30:00 #SBATCH --output=slurm-%j.out #SBATCH --error=slurm-%j.err set -euo pipefail cd /orcd/scratch/orcd/013/satra/kwyk_reproduction source /orcd/data/satra/002/projects/nobrainer/venvs/nobrainer/bin/activate echo "=== Experiment 01: Deterministic eval of Bayesian models ===" echo "Started: $(date)" python experiments/01_20260330_eval_deterministic/eval_deterministic.py echo "=== Done: $(date) ===" ================================================ FILE: 
scripts/kwyk_reproduction/experiments/02_20260330_binary_bayesian/README.md ================================================ # Experiment 02: Binary (2-class) Bayesian training ## Rationale 50-class parcellation is a hard problem. Binary brain extraction (brain vs background) is much simpler and was used in the original smoke tests. If Bayesian models can learn binary segmentation, the issue is with label complexity, not the Bayesian architecture. ## Plan 1. Use 5 subjects, binary label mapping, 20 epochs 2. Train MeshNet → warm-start bwn_multi (simplest Bayesian variant) 3. Evaluate both mc=True and mc=False 4. If binary works, the 50-class zero Dice is likely a capacity/epochs issue ## Tasks - [ ] Create binary config - [ ] Train MeshNet (binary, 5 subjects, 20 epochs) - [ ] Train bwn_multi (binary, warm-start, 20 epochs) - [ ] Evaluate both modes ================================================ FILE: scripts/kwyk_reproduction/experiments/02_20260330_binary_bayesian/config.yaml ================================================ filters: 96 receptive_field: 37 dropout_rate: 0.25 sigma_init: 0.0001 block_shape: [32, 32, 32] lr: 0.0001 batch_size: 32 n_classes: 2 label_mapping: binary pretrain_epochs: 20 bayesian_epochs: 20 n_samples: 10 variants: bwn_multi: model: kwyk_meshnet dropout_type: bernoulli mc_at_inference: true ================================================ FILE: scripts/kwyk_reproduction/experiments/02_20260330_binary_bayesian/eval_binary.py ================================================ #!/usr/bin/env python """Evaluate binary Bayesian model in both mc=True and mc=False modes.""" from __future__ import annotations import csv from pathlib import Path import sys import nibabel as nib import numpy as np import torch sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from utils import compute_dice, setup_logging # noqa: E402 log = setup_logging(__name__) EXP_DIR = Path(__file__).parent WORK_DIR = EXP_DIR.parent.parent def predict_volume(model, 
img_path, block_shape, mc=False): """Block-based prediction with mc control.""" from nobrainer.prediction import _extract_blocks, _pad_to_multiple, _stitch_blocks from nobrainer.training import get_device device = get_device() img = nib.load(str(img_path)) arr = np.asarray(img.dataobj, dtype=np.float32) orig_shape = arr.shape[:3] padded, pad = _pad_to_multiple(arr, block_shape) blocks, grid = _extract_blocks(padded, block_shape) model = model.to(device) model.eval() all_preds = [] with torch.no_grad(): for start in range(0, len(blocks), 4): chunk = blocks[start : start + 4] tensor = torch.from_numpy(chunk[:, None]).to(device) if hasattr(model, "forward") and "mc" in model.forward.__code__.co_varnames: out = model(tensor, mc=mc) else: out = model(tensor) labels = out.argmax(dim=1, keepdim=True).float() all_preds.append(labels.cpu().numpy()) block_preds = np.concatenate(all_preds, axis=0) full = _stitch_blocks(block_preds, grid, block_shape, pad, orig_shape, 1)[0] return (full > 0).astype(np.float32) def main(): from nobrainer.processing.segmentation import Segmentation manifest_path = WORK_DIR / "kwyk_sanity_manifest.csv" # Load test pairs pairs = [] with open(manifest_path) as f: for row in csv.DictReader(f): if row["split"] == "test": pairs.append((row["t1w_path"], row["label_path"])) log.info("Test volumes: %d", len(pairs)) results = [] for variant in ["meshnet", "bwn_multi"]: ckpt_dir = EXP_DIR / "checkpoints" / variant if not (ckpt_dir / "model.pth").exists(): log.warning("Skipping %s — no checkpoint", variant) continue seg = Segmentation.load(ckpt_dir) model = seg.model_ block_shape = seg.block_shape_ or (32, 32, 32) mc_modes = [False] if variant == "meshnet" else [False, True] for mc_mode in mc_modes: mode_name = "mc" if mc_mode else "deterministic" dices = [] for idx, (img_path, lbl_path) in enumerate(pairs): pred = predict_volume(model, img_path, block_shape, mc=mc_mode) gt = ( np.asarray(nib.load(lbl_path).dataobj, dtype=np.float32) > 0 
).astype(np.float32) dice = compute_dice(pred, gt) dices.append(dice) log.info( " [%s/%s] vol %d: Dice=%.4f", variant, mode_name, idx + 1, dice ) mean_d = float(np.mean(dices)) results.append( {"variant": variant, "mode": mode_name, "mean_dice": f"{mean_d:.4f}"} ) log.info(" [%s/%s] MEAN DICE: %.4f", variant, mode_name, mean_d) out_path = EXP_DIR / "results.csv" with open(out_path, "w", newline="") as f: w = csv.DictWriter(f, fieldnames=["variant", "mode", "mean_dice"]) w.writeheader() w.writerows(results) log.info("Results saved to %s", out_path) if __name__ == "__main__": main() ================================================ FILE: scripts/kwyk_reproduction/experiments/02_20260330_binary_bayesian/eval_only.sbatch ================================================ #!/bin/bash #SBATCH --job-name=exp02-eval #SBATCH --partition=mit_preemptable #SBATCH --gres=gpu:1 #SBATCH --cpus-per-task=4 #SBATCH --mem=32G #SBATCH --time=00:15:00 #SBATCH --output=slurm-eval-%j.out #SBATCH --error=slurm-eval-%j.err set -euo pipefail cd /orcd/scratch/orcd/013/satra/kwyk_reproduction source /orcd/data/satra/002/projects/nobrainer/venvs/nobrainer/bin/activate echo "=== Experiment 02: Binary eval ===" python experiments/02_20260330_binary_bayesian/eval_binary.py echo "=== Done: $(date) ===" ================================================ FILE: scripts/kwyk_reproduction/experiments/02_20260330_binary_bayesian/run.sbatch ================================================ #!/bin/bash #SBATCH --job-name=exp02-bin #SBATCH --partition=mit_preemptable #SBATCH --gres=gpu:1 #SBATCH --cpus-per-task=4 #SBATCH --mem=32G #SBATCH --time=01:00:00 #SBATCH --output=slurm-%j.out #SBATCH --error=slurm-%j.err set -euo pipefail WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction" EXP_DIR="$WORK_DIR/experiments/02_20260330_binary_bayesian" VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer" MANIFEST="$WORK_DIR/kwyk_sanity_manifest.csv" # 5 subjects CONFIG="$EXP_DIR/config.yaml" cd 
"$WORK_DIR" source "${VENV_DIR}/bin/activate" echo "=== Experiment 02: Binary Bayesian (5 subjects, 20 epochs) ===" echo "Started: $(date)" # Train MeshNet (binary) echo "=== Step 1: MeshNet (binary, 20 epochs) ===" python 02_train_meshnet.py --manifest "$MANIFEST" --config "$CONFIG" \ --output-dir "$EXP_DIR/checkpoints/meshnet" --epochs 20 # Train bwn_multi (binary, warm-start from MeshNet) echo "=== Step 2: bwn_multi (binary, 20 epochs) ===" python 03_train_bayesian.py --manifest "$MANIFEST" --config "$CONFIG" \ --variant bwn_multi --warmstart "$EXP_DIR/checkpoints/meshnet" \ --output-dir "$EXP_DIR/checkpoints/bwn_multi" --epochs 20 # Quick eval: deterministic and MC echo "=== Step 3: Eval (deterministic + MC) ===" python experiments/02_20260330_binary_bayesian/eval_binary.py echo "=== Done: $(date) ===" ================================================ FILE: scripts/kwyk_reproduction/experiments/03_20260330_warmstart_diagnostic/README.md ================================================ # Experiment 03: Warm-start transfer diagnostic ## Rationale Experiment 01 showed that Bayesian model weights have zero Dice even in deterministic mode. This means the warm-start transfer from MeshNet to KWYKMeshNet may not be working, OR the Bayesian training loop destroys transferred weights. ## Plan 1. Load trained MeshNet checkpoint 2. Create KWYKMeshNet and run warm-start transfer 3. Evaluate KWYKMeshNet immediately BEFORE any Bayesian training (mc=False) 4. If Dice > 0: warm-start works, Bayesian training is the problem 5. If Dice = 0: warm-start transfer is broken 6. 
Also compare parameter counts and shapes between MeshNet and KWYKMeshNet ## Tasks - [ ] Compare architectures - [ ] Evaluate warm-started model before training - [ ] Check transfer log messages ================================================ FILE: scripts/kwyk_reproduction/experiments/03_20260330_warmstart_diagnostic/diagnose.py ================================================ #!/usr/bin/env python """Diagnose warm-start transfer from MeshNet to KWYKMeshNet.""" from __future__ import annotations import csv from pathlib import Path import sys import nibabel as nib import numpy as np import torch sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from utils import setup_logging # noqa: E402 log = setup_logging(__name__) WORK_DIR = Path(__file__).parent.parent.parent EXP_DIR = Path(__file__).parent def predict_volume_simple(model, img_path, block_shape, mc=False): """Block-based prediction.""" from nobrainer.prediction import _extract_blocks, _pad_to_multiple, _stitch_blocks from nobrainer.training import get_device device = get_device() img = nib.load(str(img_path)) arr = np.asarray(img.dataobj, dtype=np.float32) orig_shape = arr.shape[:3] padded, pad = _pad_to_multiple(arr, block_shape) blocks, grid = _extract_blocks(padded, block_shape) model = model.to(device) model.eval() all_preds = [] with torch.no_grad(): for start in range(0, len(blocks), 4): chunk = blocks[start : start + 4] tensor = torch.from_numpy(chunk[:, None]).to(device) try: out = model(tensor, mc=mc) except TypeError: out = model(tensor) labels = out.argmax(dim=1, keepdim=True).float() all_preds.append(labels.cpu().numpy()) block_preds = np.concatenate(all_preds, axis=0) full = _stitch_blocks(block_preds, grid, block_shape, pad, orig_shape, 1)[0] return full.astype(np.int32) def main(): from nobrainer.models import get as get_model from nobrainer.processing.dataset import _load_label_mapping from nobrainer.processing.segmentation import Segmentation remap_fn = 
_load_label_mapping("50-class") n_classes = 50 block_shape = (32, 32, 32) # Load test pairs pairs = [] with open(WORK_DIR / "kwyk_manifest.csv") as f: for row in csv.DictReader(f): if row["split"] == "test": pairs.append((row["t1w_path"], row["label_path"])) # Use first test volume img_path, lbl_path = pairs[0] gt_arr = remap_fn( torch.from_numpy(np.asarray(nib.load(lbl_path).dataobj, dtype=np.int32)) ).numpy() # ---- Step 1: Load trained MeshNet and eval ---- log.info("=== Step 1: Evaluate trained MeshNet ===") seg = Segmentation.load(WORK_DIR / "checkpoints" / "kwyk_smoke_meshnet") det_model = seg.model_ log.info("MeshNet type: %s", type(det_model).__name__) log.info("MeshNet params: %d", sum(p.numel() for p in det_model.parameters())) pred = predict_volume_simple(det_model, img_path, block_shape, mc=False) # per-class dice dices = [] for c in range(1, n_classes): p = pred == c g = gt_arr == c inter = (p & g).sum() total = p.sum() + g.sum() dices.append(2.0 * inter / total if total > 0 else 1.0) log.info("MeshNet Dice: mean=%.4f, max=%.4f", np.mean(dices), np.max(dices)) # ---- Step 2: Create KWYKMeshNet and warm-start ---- log.info("=== Step 2: Warm-start KWYKMeshNet from MeshNet ===") model_args = { "n_classes": n_classes, "filters": 96, "receptive_field": 37, "dropout_type": "bernoulli", "dropout_rate": 0.25, "sigma_init": 0.0001, } kwyk_factory = get_model("kwyk_meshnet") kwyk_model = kwyk_factory(**model_args) log.info("KWYKMeshNet type: %s", type(kwyk_model).__name__) log.info("KWYKMeshNet params: %d", sum(p.numel() for p in kwyk_model.parameters())) # Print layer comparison log.info("--- MeshNet layers ---") for name, param in det_model.named_parameters(): log.info(" %s: %s", name, param.shape) log.info("--- KWYKMeshNet layers ---") for name, param in kwyk_model.named_parameters(): log.info(" %s: %s", name, param.shape) # Run warm-start from nobrainer.models.bayesian.warmstart import warmstart_kwyk_from_deterministic meshnet_ckpt = WORK_DIR / "checkpoints" 
/ "kwyk_smoke_meshnet" / "model.pth" n_transferred = warmstart_kwyk_from_deterministic(kwyk_model, str(meshnet_ckpt)) log.info("Transferred %d layers", n_transferred) # ---- Step 3: Eval KWYKMeshNet BEFORE any Bayesian training ---- log.info("=== Step 3: Evaluate warm-started KWYKMeshNet (mc=False) ===") pred = predict_volume_simple(kwyk_model, img_path, block_shape, mc=False) dices = [] for c in range(1, n_classes): p = pred == c g = gt_arr == c inter = (p & g).sum() total = p.sum() + g.sum() dices.append(2.0 * inter / total if total > 0 else 1.0) log.info( "KWYKMeshNet (warm-start, mc=False) Dice: mean=%.4f, max=%.4f", np.mean(dices), np.max(dices), ) # Also test mc=True log.info("=== Step 4: Evaluate warm-started KWYKMeshNet (mc=True) ===") pred = predict_volume_simple(kwyk_model, img_path, block_shape, mc=True) dices = [] for c in range(1, n_classes): p = pred == c g = gt_arr == c inter = (p & g).sum() total = p.sum() + g.sum() dices.append(2.0 * inter / total if total > 0 else 1.0) log.info( "KWYKMeshNet (warm-start, mc=True) Dice: mean=%.4f, max=%.4f", np.mean(dices), np.max(dices), ) # ---- Step 5: Check what 03_train_bayesian does ---- log.info("=== Step 5: Check how training script loads warm-start ===") # Read the training script to see if it uses warmstart_kwyk_from_deterministic train_script = WORK_DIR / "03_train_bayesian.py" with open(train_script) as f: content = f.read() if "warmstart_kwyk_from_deterministic" in content: log.info("Training script uses warmstart_kwyk_from_deterministic") elif "warmstart_bayesian_from_deterministic" in content: log.info( "Training script uses warmstart_bayesian_from_deterministic (WRONG for KWYK!)" ) else: log.info("No warmstart function found in training script — check manually") # Grep for the relevant line for line_no, line in enumerate(content.split("\n"), 1): if "warmstart" in line.lower() and not line.strip().startswith("#"): log.info(" Line %d: %s", line_no, line.strip()) if __name__ == "__main__": main() 
================================================
FILE: scripts/kwyk_reproduction/experiments/03_20260330_warmstart_diagnostic/results_summary.md
================================================
# Experiment 03 Results

## Key Finding: Warm-start transfer bug — sorted key ordering mismatch

**MeshNet Dice (trained):** mean=0.0132, max=0.4738

**KWYKMeshNet (warm-start, mc=False):** mean=0.0006, max=0.0101

**KWYKMeshNet (warm-start, mc=True):** mean=0.0006, max=0.0103

Only 5 of 7+1 layers transferred successfully.

## Root Cause

`warmstart_kwyk_from_deterministic()` sorts state dict keys alphabetically:

```python
det_convs = [(k, v) for k, v in sorted(state.items()) if "weight" in k and v.ndim == 5]
```

This produces ordering: `classifier.weight, encoder.0, encoder.1, ...`

But KWYKMeshNet FFGConv3d layers are: `layer_0, layer_1, ...` (no classifier — it's a regular Conv3d)

So `classifier.weight [50,96,1,1,1]` pairs with `layer_0.conv [96,1,3,3,3]` → shape mismatch!
Then `encoder.0 [96,1,3,3,3]` pairs with `layer_1.conv [96,96,3,3,3]` → shape mismatch!
Then `encoder.1 [96,96,3,3,3]` pairs with `layer_2.conv [96,96,3,3,3]` → OK (but wrong weights!)

Result: 5 layers "transferred", but every transferred encoder conv lands one layer too deep (e.g. encoder.1→layer_2 instead of encoder.1→layer_1). The first two layers keep their random initialization.

## Fix

Filter out the classifier weight before pairing, OR use explicit name matching. Plain alphabetical sorting would also misorder layers once there are ten or more (`encoder.10` sorts before `encoder.2`). Any ordering must therefore compare layer indices numerically.

## Conclusion

The near-zero Bayesian Dice is caused by a broken warm-start. The model starts from mostly-random weights, and 20 epochs isn't enough to learn from scratch.
================================================ FILE: scripts/kwyk_reproduction/experiments/03_20260330_warmstart_diagnostic/run.sbatch ================================================ #!/bin/bash #SBATCH --job-name=exp03-ws #SBATCH --partition=mit_preemptable #SBATCH --gres=gpu:1 #SBATCH --cpus-per-task=4 #SBATCH --mem=32G #SBATCH --time=00:15:00 #SBATCH --output=slurm-%j.out #SBATCH --error=slurm-%j.err set -euo pipefail cd /orcd/scratch/orcd/013/satra/kwyk_reproduction source /orcd/data/satra/002/projects/nobrainer/venvs/nobrainer/bin/activate echo "=== Experiment 03: Warm-start diagnostic ===" python experiments/03_20260330_warmstart_diagnostic/diagnose.py echo "=== Done: $(date) ===" ================================================ FILE: scripts/kwyk_reproduction/experiments/04_20260330_fixed_warmstart/README.md ================================================ # Experiment 04: Fix warm-start and verify Bayesian learning ## Rationale Experiment 03 found that `warmstart_kwyk_from_deterministic()` has a key ordering bug: sorted() puts classifier.weight before encoder.X, causing all layer pairings to be offset. Fix the transfer, verify Dice is preserved, then train Bayesian. ## Plan 1. Fix warm-start: filter classifier from det_convs, transfer it separately 2. Verify fixed warm-start preserves MeshNet Dice 3. Train bwn_multi for 20 epochs with fixed warm-start 4. Evaluate in both mc=False and mc=True modes 5. 
Use 5 subjects for speed (sanity manifest) ## Tasks - [ ] Fix warm-start function - [ ] Verify transfer preserves Dice - [ ] Train and evaluate Bayesian ================================================ FILE: scripts/kwyk_reproduction/experiments/04_20260330_fixed_warmstart/run.py ================================================ #!/usr/bin/env python """Fix warm-start, verify transfer, train Bayesian, evaluate.""" from __future__ import annotations import csv from pathlib import Path import sys import nibabel as nib import numpy as np import torch import torch.nn as nn sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from utils import setup_logging # noqa: E402 log = setup_logging(__name__) WORK_DIR = Path(__file__).parent.parent.parent EXP_DIR = Path(__file__).parent def fixed_warmstart_kwyk(kwyk_model, det_weights_path): """Fixed warm-start: filter classifier, match encoder layers correctly.""" from nobrainer.models.bayesian.vwn_layers import FFGConv3d state = torch.load(det_weights_path, weights_only=True) # Separate encoder convs from classifier encoder_convs = [] classifier_weight = None classifier_bias = None for k in sorted(state.keys()): v = state[k] if k == "classifier.weight" and v.ndim == 5: classifier_weight = v elif k == "classifier.bias": classifier_bias = v elif "weight" in k and v.ndim == 5: encoder_convs.append((k, v)) log.info("Found %d encoder convs + classifier in MeshNet", len(encoder_convs)) # Transfer encoder convs to FFGConv3d layers kwyk_convs = [ (n, m) for n, m in kwyk_model.named_modules() if isinstance(m, FFGConv3d) ] transferred = 0 for (det_name, det_w), (kwyk_name, kwyk_conv) in zip(encoder_convs, kwyk_convs): if det_w.shape != kwyk_conv.v.shape: log.warning( "Shape mismatch: %s %s vs %s.v %s", det_name, det_w.shape, kwyk_name, kwyk_conv.v.shape, ) continue kwyk_conv.v.data.copy_(det_w) norms = det_w.flatten(1).norm(dim=1).view_as(kwyk_conv.g) kwyk_conv.g.data.copy_(norms) transferred += 1 log.info(" %s -> %s", det_name, 
kwyk_name) # Transfer classifier if classifier_weight is not None and hasattr(kwyk_model, "classifier"): kwyk_model.classifier.weight.data.copy_(classifier_weight) if classifier_bias is not None: kwyk_model.classifier.bias.data.copy_(classifier_bias) log.info(" classifier transferred") transferred += 1 log.info("Total transferred: %d layers", transferred) return transferred def predict_volume(model, img_path, block_shape, mc=False): """Block-based prediction.""" from nobrainer.prediction import _extract_blocks, _pad_to_multiple, _stitch_blocks from nobrainer.training import get_device device = get_device() img = nib.load(str(img_path)) arr = np.asarray(img.dataobj, dtype=np.float32) orig_shape = arr.shape[:3] padded, pad = _pad_to_multiple(arr, block_shape) blocks, grid = _extract_blocks(padded, block_shape) model = model.to(device) model.eval() all_preds = [] with torch.no_grad(): for start in range(0, len(blocks), 4): chunk = blocks[start : start + 4] tensor = torch.from_numpy(chunk[:, None]).to(device) try: out = model(tensor, mc=mc) except TypeError: out = model(tensor) labels = out.argmax(dim=1, keepdim=True).float() all_preds.append(labels.cpu().numpy()) block_preds = np.concatenate(all_preds, axis=0) full = _stitch_blocks(block_preds, grid, block_shape, pad, orig_shape, 1)[0] return full.astype(np.int32) def per_class_dice(pred, gt, n_classes): dices = [] for c in range(1, n_classes): p = pred == c g = gt == c inter = (p & g).sum() total = p.sum() + g.sum() dices.append(2.0 * inter / total if total > 0 else 1.0) return np.array(dices) def main(): from nobrainer.models import get as get_model from nobrainer.processing.dataset import Dataset, _load_label_mapping from nobrainer.processing.segmentation import Segmentation n_classes = 50 block_shape = (32, 32, 32) remap_fn = _load_label_mapping("50-class") # Test volume pairs = [] with open(WORK_DIR / "kwyk_sanity_manifest.csv") as f: for row in csv.DictReader(f): if row["split"] == "test": 
pairs.append((row["t1w_path"], row["label_path"])) img_path, lbl_path = pairs[0] gt_arr = remap_fn( torch.from_numpy(np.asarray(nib.load(lbl_path).dataobj, dtype=np.int32)) ).numpy() # ---- Step 1: Verify MeshNet baseline ---- log.info("=== Step 1: MeshNet baseline ===") seg = Segmentation.load(WORK_DIR / "checkpoints" / "sanity_meshnet") det_model = seg.model_ pred = predict_volume(det_model, img_path, block_shape) cd = per_class_dice(pred, gt_arr, n_classes) log.info("MeshNet: mean=%.4f, max=%.4f", cd.mean(), cd.max()) # ---- Step 2: Fixed warm-start ---- log.info("=== Step 2: Fixed warm-start ===") kwyk_factory = get_model("kwyk_meshnet") kwyk_model = kwyk_factory( n_classes=n_classes, filters=96, receptive_field=37, dropout_type="bernoulli", dropout_rate=0.25, sigma_init=0.0001, ) meshnet_ckpt = WORK_DIR / "checkpoints" / "sanity_meshnet" / "model.pth" fixed_warmstart_kwyk(kwyk_model, meshnet_ckpt) # Eval immediately pred = predict_volume(kwyk_model, img_path, block_shape, mc=False) cd = per_class_dice(pred, gt_arr, n_classes) log.info( "KWYKMeshNet fixed warm-start (mc=False): mean=%.4f, max=%.4f", cd.mean(), cd.max(), ) pred = predict_volume(kwyk_model, img_path, block_shape, mc=True) cd = per_class_dice(pred, gt_arr, n_classes) log.info( "KWYKMeshNet fixed warm-start (mc=True): mean=%.4f, max=%.4f", cd.mean(), cd.max(), ) # ---- Step 3: Train Bayesian (5 subjects, 20 epochs) ---- log.info("=== Step 3: Train Bayesian with fixed warm-start ===") manifest = WORK_DIR / "kwyk_sanity_manifest.csv" label_mapping = "50-class" train_pairs = [] val_pairs = [] with open(manifest) as f: for row in csv.DictReader(f): p = (row["t1w_path"], row["label_path"]) if row["split"] == "train": train_pairs.append(p) elif row["split"] == "val": val_pairs.append(p) ds_train = ( Dataset.from_files(train_pairs, block_shape=block_shape, n_classes=n_classes) .batch(32) .binarize(label_mapping) ) from nobrainer.training import get_device device = get_device() kwyk_model = 
kwyk_model.to(device) optimizer = torch.optim.Adam(kwyk_model.parameters(), lr=0.0001) criterion = nn.CrossEntropyLoss() for epoch in range(20): kwyk_model.train() epoch_loss = 0.0 n_batches = 0 for batch in ds_train.dataloader: if isinstance(batch, dict): images = batch["image"].to(device) labels = batch["label"].to(device) else: images = batch[0].to(device) labels = batch[1].to(device) if labels.ndim == images.ndim and labels.shape[1] == 1: labels = labels.squeeze(1) if labels.dtype in (torch.float32, torch.float64): labels = labels.long() optimizer.zero_grad() pred_t = kwyk_model(images, mc=True) loss = criterion(pred_t, labels) loss.backward() optimizer.step() epoch_loss += loss.item() n_batches += 1 avg_loss = epoch_loss / max(n_batches, 1) msg = f"Epoch {epoch + 1}/20: loss={avg_loss:.4f}" # Eval every 5 epochs if (epoch + 1) % 5 == 0: pred = predict_volume(kwyk_model, img_path, block_shape, mc=False) cd = per_class_dice(pred, gt_arr, n_classes) msg += f" dice_det={cd.mean():.4f}/{cd.max():.4f}" pred = predict_volume(kwyk_model, img_path, block_shape, mc=True) cd_mc = per_class_dice(pred, gt_arr, n_classes) msg += f" dice_mc={cd_mc.mean():.4f}/{cd_mc.max():.4f}" log.info(msg) # ---- Final eval ---- log.info("=== Final evaluation ===") for mc_mode in [False, True]: pred = predict_volume(kwyk_model, img_path, block_shape, mc=mc_mode) cd = per_class_dice(pred, gt_arr, n_classes) mode = "mc" if mc_mode else "det" log.info("Final [%s]: mean=%.4f, max=%.4f", mode, cd.mean(), cd.max()) # Save model torch.save(kwyk_model.state_dict(), EXP_DIR / "kwyk_fixed_warmstart.pth") log.info("Model saved to %s", EXP_DIR / "kwyk_fixed_warmstart.pth") if __name__ == "__main__": main() ================================================ FILE: scripts/kwyk_reproduction/experiments/04_20260330_fixed_warmstart/run.sbatch ================================================ #!/bin/bash #SBATCH --job-name=exp04-fix #SBATCH --partition=mit_preemptable #SBATCH --gres=gpu:1 #SBATCH 
--cpus-per-task=4 #SBATCH --mem=32G #SBATCH --time=00:30:00 #SBATCH --output=slurm-%j.out #SBATCH --error=slurm-%j.err set -euo pipefail cd /orcd/scratch/orcd/013/satra/kwyk_reproduction source /orcd/data/satra/002/projects/nobrainer/venvs/nobrainer/bin/activate echo "=== Experiment 04: Fixed warm-start + Bayesian training ===" python experiments/04_20260330_fixed_warmstart/run.py echo "=== Done: $(date) ===" ================================================ FILE: scripts/kwyk_reproduction/experiments/05_20260330_kwyk_from_scratch/README.md ================================================ # Experiment 05: KWYKMeshNet from scratch (no warm-start) ## Rationale Experiments 03-04 showed warm-start doesn't transfer well. But the Bayesian training also shows zero Dice after 20 epochs from scratch. The question: can KWYKMeshNet learn AT ALL with mc=False (deterministic)? If not, the VWN architecture + CrossEntropyLoss may have a fundamental issue. We also test: (a) mc=False during training, (b) binary labels for simplicity. ## Plan 1. Train KWYKMeshNet (50-class, mc=True during training, 5 subj, 50 epochs) 2. Train KWYKMeshNet (50-class, mc=FALSE during training, 5 subj, 50 epochs) 3. Train KWYKMeshNet (binary, mc=False, 5 subj, 50 epochs) 4. Compare: does turning off MC during training help? Does binary help? ================================================ FILE: scripts/kwyk_reproduction/experiments/05_20260330_kwyk_from_scratch/results_summary.md ================================================ # Experiment 05 Results ## Key Finding: mc=False during training is REQUIRED for KWYKMeshNet to learn | Condition | mc_train | n_classes | Final Loss | Final Dice (det) mean/max | |---|---|---|---|---| | A | True | 50 | 3.387 | 0.0000/0.0006 | | **B** | **False** | **50** | **2.620** | **0.0019/0.0936** | | C | True | 2 | 1.011 | 0.0000 | | D | False | 2 | 0.910 | 0.0001 | ## Analysis 1. 
**mc=True kills training**: Conditions A and C (mc=True) both converge to zero Dice despite loss decreasing. The local reparameterization noise from FFGConv3d prevents stable gradient flow. The loss landscape becomes too noisy. 2. **mc=False allows learning**: Condition B (mc=False, 50-class) achieves 9.4% Dice on the best class — comparable to the deterministic MeshNet at similar epoch count. The VWN weight normalization itself is fine; the stochastic sampling is the issue. 3. **Binary fails for both**: Possibly an issue with binary evaluation or the model having too many parameters for a 2-class problem (3M params for binary). 4. **Loss instability with mc=True**: Condition A shows wild loss swings (1.3 to 5.3) because each forward pass samples different weights. mc=False gives stable loss. ## Conclusion The Bayesian training should use `mc=False` for the forward pass during gradient computation, and only enable `mc=True` at inference time for uncertainty estimation. This is the standard approach: train with deterministic weights, use stochastic inference. The current code passes `mc=True` during training which prevents learning. ## Recommendation Fix `03_train_bayesian.py` to call `model(images, mc=False)` during training, and only use `mc=True` for validation MC Dice evaluation. 
================================================
FILE: scripts/kwyk_reproduction/experiments/05_20260330_kwyk_from_scratch/run.py
================================================
#!/usr/bin/env python
"""Train KWYKMeshNet from scratch: mc=True vs mc=False, 50-class vs binary."""

from __future__ import annotations

import csv
from pathlib import Path
import sys
import time

import nibabel as nib
import numpy as np
import torch
import torch.nn as nn

sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from utils import setup_logging  # noqa: E402

log = setup_logging(__name__)

WORK_DIR = Path(__file__).parent.parent.parent
EXP_DIR = Path(__file__).parent


def predict_volume(model, img_path, block_shape, mc=False):
    """Predict an int32 label volume block by block.

    Pads the image to a multiple of ``block_shape``, runs blocks through
    ``model(tensor, mc=mc)`` four at a time under ``torch.no_grad()``, takes
    the argmax over dim 1 and stitches back to the image's first three dims.
    Leaves ``model`` on the device in eval mode; callers that keep training
    must call ``model.train()`` again (train_kwyk does so every epoch).
    """
    from nobrainer.prediction import _extract_blocks, _pad_to_multiple, _stitch_blocks
    from nobrainer.training import get_device

    device = get_device()
    img = nib.load(str(img_path))
    arr = np.asarray(img.dataobj, dtype=np.float32)
    orig_shape = arr.shape[:3]
    padded, pad = _pad_to_multiple(arr, block_shape)
    blocks, grid = _extract_blocks(padded, block_shape)
    model = model.to(device)
    model.eval()
    all_preds = []
    with torch.no_grad():
        for start in range(0, len(blocks), 4):
            chunk = blocks[start : start + 4]
            tensor = torch.from_numpy(chunk[:, None]).to(device)
            out = model(tensor, mc=mc)
            all_preds.append(out.argmax(dim=1, keepdim=True).float().cpu().numpy())
    block_preds = np.concatenate(all_preds, axis=0)
    return _stitch_blocks(block_preds, grid, block_shape, pad, orig_shape, 1)[0].astype(
        np.int32
    )


def per_class_dice(pred, gt, n_classes):
    """Dice for labels 1..n_classes-1 (0 skipped); labels absent from both score 1.0."""
    dices = []
    for c in range(1, n_classes):
        p = pred == c
        g = gt == c
        inter = (p & g).sum()
        total = p.sum() + g.sum()
        dices.append(2.0 * inter / total if total > 0 else 1.0)
    return np.array(dices)


def binary_dice(pred, gt):
    """Foreground (label > 0) Dice; returns 1.0 when both masks are empty."""
    pred = (pred > 0).astype(bool)
    gt = (gt > 0).astype(bool)
    inter = (pred & gt).sum()
    total = pred.sum() + gt.sum()
    return 2.0 * inter / total if total > 0 else 1.0


def train_kwyk(name, n_classes, label_mapping, mc_train, epochs=50):
    """Train a fresh KWYKMeshNet on the sanity manifest and log Dice.

    Args:
        name: run label used only in log output.
        n_classes: number of output classes (2 selects binary Dice reporting).
        label_mapping: passed to ``Dataset.binarize``; for anything other than
            "binary" it also selects the remap applied to the eval labels.
        mc_train: value of ``mc`` passed to the model's forward in training.
        epochs: number of passes over the training loader.

    Returns:
        The trained model.
    """
    from nobrainer.models import get as get_model
    from nobrainer.processing.dataset import Dataset, _load_label_mapping
    from nobrainer.training import get_device

    log.info(
        "=== %s: n_classes=%d, mc_train=%s, epochs=%d ===",
        name,
        n_classes,
        mc_train,
        epochs,
    )
    block_shape = (32, 32, 32)
    device = get_device()

    # Load data
    train_pairs, val_pairs = [], []
    with open(WORK_DIR / "kwyk_sanity_manifest.csv") as f:
        for row in csv.DictReader(f):
            p = (row["t1w_path"], row["label_path"])
            if row["split"] == "train":
                train_pairs.append(p)
            elif row["split"] == "test":
                val_pairs.append(p)  # use test as val for sanity

    ds = (
        Dataset.from_files(train_pairs, block_shape=block_shape, n_classes=n_classes)
        .batch(32)
        .binarize(label_mapping)
    )

    # Remap function for eval
    remap_fn = None
    if label_mapping and label_mapping != "binary":
        remap_fn = _load_label_mapping(label_mapping)

    # Create model
    model = get_model("kwyk_meshnet")(
        n_classes=n_classes,
        filters=96,
        receptive_field=37,
        dropout_type="bernoulli",
        dropout_rate=0.25,
        sigma_init=0.0001,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.CrossEntropyLoss()

    # Test volume
    # NOTE(review): raises IndexError if the manifest has no "test" rows.
    img_path, lbl_path = val_pairs[0]
    gt_raw = np.asarray(nib.load(lbl_path).dataobj, dtype=np.int32)
    if remap_fn:
        gt_arr = remap_fn(torch.from_numpy(gt_raw)).numpy()
    else:
        # Binary mapping: any non-zero label counts as foreground.
        gt_arr = (gt_raw > 0).astype(np.int32)

    t0 = time.time()
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        n_batches = 0
        for batch in ds.dataloader:
            # Loader may yield dicts or (image, label) tuples.
            if isinstance(batch, dict):
                images = batch["image"].to(device)
                labels = batch["label"].to(device)
            else:
                images = batch[0].to(device)
                labels = batch[1].to(device)
            # CrossEntropyLoss needs integer targets without a channel dim.
            if labels.ndim == images.ndim and labels.shape[1] == 1:
                labels = labels.squeeze(1)
            if labels.dtype in (torch.float32, torch.float64):
                labels = labels.long()

            optimizer.zero_grad()
            pred = model(images, mc=mc_train)
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1

        avg_loss = epoch_loss / max(n_batches, 1)
        msg = f" Epoch {epoch + 1}/{epochs}: loss={avg_loss:.4f}"

        # Deterministic (mc=False) Dice on the first epoch and every 10th.
        if (epoch + 1) % 10 == 0 or epoch == 0:
            pred_vol = predict_volume(model, img_path, block_shape, mc=False)
            if n_classes == 2:
                d = binary_dice(pred_vol, gt_arr)
                msg += f" dice_det={d:.4f}"
            else:
                cd = per_class_dice(pred_vol, gt_arr, n_classes)
                msg += f" dice_det={cd.mean():.4f}/{cd.max():.4f}"

        log.info(msg)

    elapsed = time.time() - t0
    log.info(" Completed in %.1fs", elapsed)

    # Final eval
    pred_vol = predict_volume(model, img_path, block_shape, mc=False)
    if n_classes == 2:
        d = binary_dice(pred_vol, gt_arr)
        log.info(" FINAL [det]: dice=%.4f", d)
    else:
        cd = per_class_dice(pred_vol, gt_arr, n_classes)
        log.info(" FINAL [det]: mean=%.4f, max=%.4f", cd.mean(), cd.max())

    # MC evaluation only for models that were trained with mc=True.
    if mc_train:
        pred_vol = predict_volume(model, img_path, block_shape, mc=True)
        if n_classes == 2:
            d = binary_dice(pred_vol, gt_arr)
            log.info(" FINAL [mc]: dice=%.4f", d)
        else:
            cd = per_class_dice(pred_vol, gt_arr, n_classes)
            log.info(" FINAL [mc]: mean=%.4f, max=%.4f", cd.mean(), cd.max())

    return model


def main():
    """Run conditions A-D sequentially (see README for the experiment design)."""
    # A: 50-class, mc=True during training (current default)
    train_kwyk("A_50class_mcTrue", 50, "50-class", mc_train=True, epochs=50)
    # B: 50-class, mc=False during training (deterministic forward)
    train_kwyk("B_50class_mcFalse", 50, "50-class", mc_train=False, epochs=50)
    # C: Binary, mc=True during training
    train_kwyk("C_binary_mcTrue", 2, "binary", mc_train=True, epochs=50)
    # D: Binary, mc=False during training
    train_kwyk("D_binary_mcFalse", 2, "binary", mc_train=False, epochs=50)


if __name__ == "__main__":
    main()

================================================
FILE: scripts/kwyk_reproduction/experiments/05_20260330_kwyk_from_scratch/run.sbatch
================================================
#!/bin/bash
#SBATCH --job-name=exp05-scratch
#SBATCH --partition=mit_preemptable
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=4
#SBATCH --mem=32G
#SBATCH --time=01:30:00
#SBATCH --output=slurm-%j.out
#SBATCH --error=slurm-%j.err

set -euo pipefail

# Run all four from-scratch conditions (A-D) in one job.
cd /orcd/scratch/orcd/013/satra/kwyk_reproduction
source /orcd/data/satra/002/projects/nobrainer/venvs/nobrainer/bin/activate

echo "=== Experiment 05: KWYKMeshNet from scratch ==="
python experiments/05_20260330_kwyk_from_scratch/run.py
echo "=== Done: $(date) ==="

================================================
FILE: scripts/kwyk_reproduction/experiments/06_20260331_fullvol_augment/README.md
================================================
# Experiment 06: Full-volume (256³) training with augmentation

## Rationale

Current training uses 32³ patches — the model never sees global context. Training on full 256³ volumes should improve segmentation quality, especially for large structures. Combined with augmentation (affine + flip + noise) for regularization.

## Plan

1. Use 128³ blocks on L40S (batch_size=4 fits in 47GB) OR request H200/A100 for full 256³ (batch_size=1 per GPU, 2 GPUs)
2. Standard augmentation profile (affine rotation/scale, flips, Gaussian noise)
3. MeshNet first (deterministic baseline), then bwn_multi
4. 20 epochs on 500 subjects
5. Compare Dice vs 32³ patch training

## GPU Options

| GPU | Memory | 256³ batch=1 | 128³ batch=4 |
|-----|--------|-------------|-------------|
| L40S (47GB) | 47GB | OOM (90GB) | OK (11GB) |
| A100 (80GB) | 80GB | Tight | OK |
| H200 (141GB) | 141GB | OK | OK |
| 2× L40S | 94GB | OK | OK |

================================================
FILE: scripts/kwyk_reproduction/experiments/06_20260331_fullvol_augment/config_256.yaml
================================================
# Full 256³ volume training — requires H200 or multi-GPU
filters: 96
receptive_field: 37
dropout_rate: 0.25
sigma_init: 0.0001
block_shape: [256, 256, 256]
lr: 0.0001
batch_size: 1
n_classes: 50
label_mapping: 50-class
patches_per_volume: 1  # whole volume = 1 patch
pretrain_epochs: 20
bayesian_epochs: 20
val_dice_freq: 5
n_samples: 10
augmentation_profile: standard
gradient_checkpointing: true
zarr_store: /orcd/scratch/orcd/013/satra/kwyk_reproduction/data/kwyk_500.zarr
datasets:
  - kwyk
split: [80, 10, 10]
variants:
  bwn_multi:
    model: kwyk_meshnet
    dropout_type: bernoulli
    mc_at_inference: true

================================================
FILE: scripts/kwyk_reproduction/experiments/06_20260331_fullvol_augment/config_256_mp.yaml
================================================
# Full 256³ volume — model parallel across 2 GPUs
filters: 96
receptive_field: 37
dropout_rate: 0.25
sigma_init: 0.0001
block_shape: [256, 256, 256]
lr: 0.0001
batch_size: 1
n_classes: 50
label_mapping: 50-class
patches_per_volume: 1
pretrain_epochs: 20
bayesian_epochs: 20
val_dice_freq: 5
n_samples: 10
model_parallel: true
zarr_store: /orcd/scratch/orcd/013/satra/kwyk_reproduction/data/kwyk_500.zarr
datasets:
  - kwyk
split: [80, 10, 10]
variants:
  bwn_multi:
    model: kwyk_meshnet
    dropout_type: bernoulli
    mc_at_inference: true

================================================
FILE: scripts/kwyk_reproduction/experiments/06_20260331_fullvol_augment/config_fullvol.yaml
================================================
# Full-volume training with augmentation — 50-class parcellation
# Uses 128³ blocks (fits on L40S) or 256³ on larger GPUs
filters: 96
receptive_field: 37
dropout_rate: 0.25
sigma_init: 0.0001
block_shape: [128, 128, 128]
lr: 0.0001
batch_size: 4
n_classes: 50
label_mapping: 50-class
patches_per_volume: 8  # fewer patches needed with large blocks
pretrain_epochs: 20
bayesian_epochs: 20
val_dice_freq: 5
n_samples: 10
augmentation_profile: standard
datasets:
  - kwyk
split: [80, 10, 10]
variants:
  bwn_multi:
    model: kwyk_meshnet
    dropout_type: bernoulli
    mc_at_inference: true

================================================
FILE: scripts/kwyk_reproduction/experiments/06_20260331_fullvol_augment/run_128.sbatch
================================================
#!/bin/bash
#SBATCH --job-name=exp06-128
#SBATCH --partition=mit_preemptable
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=8
#SBATCH --mem=64G
#SBATCH --time=06:00:00
#SBATCH --output=slurm-128-%j.out
#SBATCH --error=slurm-128-%j.err
#
# Experiment 06a: 128³ blocks + augmentation on L40S
#
set -euo pipefail

WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction"
EXP_DIR="$WORK_DIR/experiments/06_20260331_fullvol_augment"
VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer"
CONFIG="$EXP_DIR/config_fullvol.yaml"
MANIFEST="$WORK_DIR/kwyk_manifest_500.csv"
# NOTE(review): ZARR_STORE is not referenced below, and config_fullvol.yaml
# has no zarr_store key.
ZARR_STORE="$WORK_DIR/data/kwyk_500.zarr"

cd "$WORK_DIR"
source "${VENV_DIR}/bin/activate"

echo "=== Experiment 06a: 128³ + augmentation ==="
echo "Node: $(hostname)"
# Log which GPU the job landed on and its total memory.
python -c "import torch; print(f'GPU: {torch.cuda.get_device_name(0)}, {torch.cuda.get_device_properties(0).total_memory/1e9:.0f}GB')"

# MeshNet 128³ (20 epochs)
echo "=== Step 1: MeshNet 128³ ==="
python 02_train_meshnet.py --manifest "$MANIFEST" --config "$CONFIG" \
  --output-dir "$EXP_DIR/checkpoints/meshnet_128" --epochs 20

# bwn_multi 128³ (20 epochs, warm-start)
echo "=== Step 2: bwn_multi 128³ ==="
python 03_train_bayesian.py --manifest "$MANIFEST" --config "$CONFIG" \
  --variant bwn_multi --warmstart "$EXP_DIR/checkpoints/meshnet_128" \
  --output-dir "$EXP_DIR/checkpoints/bwn_multi_128" --epochs 20

# Evaluate
# Only variants whose training wrote a model.pth are evaluated.
# NOTE(review): --n-samples 0 presumably disables MC sampling (the banner
# labels these runs "deterministic") — confirm in 04_evaluate.py.
echo "=== Step 3: Evaluate ==="
for v in meshnet_128 bwn_multi_128; do
  if [ -f "$EXP_DIR/checkpoints/$v/model.pth" ]; then
    echo "--- $v (deterministic) ---"
    python 04_evaluate.py --model "$EXP_DIR/checkpoints/$v" \
      --manifest "$MANIFEST" --config "$CONFIG" \
      --split test --n-samples 0 \
      --output-dir "$EXP_DIR/results/${v}_det"
  fi
done

echo "=== Done: $(date) ==="

================================================
FILE: scripts/kwyk_reproduction/experiments/06_20260331_fullvol_augment/run_256.sbatch
================================================
#!/bin/bash
#SBATCH --job-name=exp06-256
#SBATCH --partition=mit_preemptable
#SBATCH --gres=gpu:h200:1
#SBATCH --cpus-per-task=8
#SBATCH --mem=128G
#SBATCH --time=06:00:00
#SBATCH --output=slurm-256-%j.out
#SBATCH --error=slurm-256-%j.err
#
# Experiment 06b: Full 256³ volume + augmentation on H200
#
set -euo pipefail

WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction"
EXP_DIR="$WORK_DIR/experiments/06_20260331_fullvol_augment"
VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer"
CONFIG="$EXP_DIR/config_256.yaml"
MANIFEST="$WORK_DIR/kwyk_manifest_500.csv"

cd "$WORK_DIR"
source "${VENV_DIR}/bin/activate"

echo "=== Experiment 06b: Full 256³ + augmentation ==="
echo "Node: $(hostname)"
# Log which GPU the job landed on and its total memory.
python -c "import torch; print(f'GPU: {torch.cuda.get_device_name(0)}, {torch.cuda.get_device_properties(0).total_memory/1e9:.0f}GB')"

# MeshNet 256³ (20 epochs)
echo "=== Step 1: MeshNet 256³ ==="
python 02_train_meshnet.py --manifest "$MANIFEST" --config "$CONFIG" \
  --output-dir "$EXP_DIR/checkpoints/meshnet_256" --epochs 20

# bwn_multi 256³ (20 epochs, warm-start)
echo "=== Step 2: bwn_multi 256³ ==="
python 03_train_bayesian.py --manifest "$MANIFEST" --config "$CONFIG" \
  --variant bwn_multi --warmstart "$EXP_DIR/checkpoints/meshnet_256" \
  --output-dir "$EXP_DIR/checkpoints/bwn_multi_256" --epochs 20

# Evaluate
# Only variants whose training wrote a model.pth are evaluated.
echo "=== Step 3: Evaluate ==="
for v in meshnet_256 bwn_multi_256; do
  if [ -f "$EXP_DIR/checkpoints/$v/model.pth" ]; then
    echo "--- $v (deterministic) ---"
    python 04_evaluate.py --model "$EXP_DIR/checkpoints/$v" \
      --manifest "$MANIFEST" --config "$CONFIG" \
      --split test --n-samples 0 \
      --output-dir "$EXP_DIR/results/${v}_det"
  fi
done

echo "=== Done: $(date) ==="

================================================
FILE: scripts/kwyk_reproduction/experiments/06_20260331_fullvol_augment/run_256_a100.sbatch
================================================
#!/bin/bash
#SBATCH --job-name=exp06-256
#SBATCH --partition=mit_preemptable
#SBATCH --gres=gpu:h200:1
#SBATCH --cpus-per-task=8
#SBATCH --mem=128G
#SBATCH --time=12:00:00
#SBATCH --output=slurm-256-%j.out
#SBATCH --error=slurm-256-%j.err
#
# Experiment 06b: Full 256³ volume + augmentation on H200 (141GB)
# Single H200 fits 256³ with batch=1 (~90GB forward+backward)
#
set -euo pipefail

WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction"
EXP_DIR="$WORK_DIR/experiments/06_20260331_fullvol_augment"
VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer"
CONFIG="$EXP_DIR/config_256.yaml"
MANIFEST="$WORK_DIR/kwyk_manifest_500.csv"

cd "$WORK_DIR"
source "${VENV_DIR}/bin/activate"

# NOTE(review): the banner below says "2× A100", but this job requests a
# single H200 (--gres=gpu:h200:1).
echo "=== Experiment 06b: Full 256³ + augmentation (2× A100) ==="
echo "Node: $(hostname)"
# Log every visible GPU and its total memory.
python -c "
import torch
n = torch.cuda.device_count()
for i in range(n):
    print(f'GPU {i}: {torch.cuda.get_device_name(i)}, {torch.cuda.get_device_properties(i).total_memory/1e9:.0f}GB')
"

# NOTE(review): uses the same checkpoint directories (meshnet_256,
# bwn_multi_256) as run_256.sbatch; running both concurrently would make them
# overwrite each other's checkpoints.
# MeshNet 256³ (20 epochs)
echo "=== Step 1: MeshNet 256³ ==="
python 02_train_meshnet.py --manifest "$MANIFEST" --config "$CONFIG" \
  --output-dir "$EXP_DIR/checkpoints/meshnet_256" --epochs 20

# bwn_multi 256³ (20 epochs, warm-start)
echo "=== Step 2: bwn_multi 256³ ==="
python 03_train_bayesian.py --manifest "$MANIFEST" --config "$CONFIG" \
  --variant bwn_multi --warmstart "$EXP_DIR/checkpoints/meshnet_256" \
  --output-dir "$EXP_DIR/checkpoints/bwn_multi_256"
--epochs 20 # Evaluate echo "=== Step 3: Evaluate ===" for v in meshnet_256 bwn_multi_256; do if [ -f "$EXP_DIR/checkpoints/$v/model.pth" ]; then echo "--- $v (deterministic) ---" python 04_evaluate.py --model "$EXP_DIR/checkpoints/$v" \ --manifest "$MANIFEST" --config "$CONFIG" \ --split test --n-samples 0 \ --output-dir "$EXP_DIR/results/${v}_det" fi done echo "=== Done: $(date) ===" ================================================ FILE: scripts/kwyk_reproduction/experiments/06_20260331_fullvol_augment/run_256_gradckpt.sbatch ================================================ #!/bin/bash #SBATCH --job-name=exp06-gc #SBATCH --partition=pi_satra #SBATCH --gres=gpu:a100:1 #SBATCH --cpus-per-task=8 #SBATCH --mem=128G #SBATCH --time=12:00:00 #SBATCH --output=slurm-256gc-%j.out #SBATCH --error=slurm-256gc-%j.err # # Experiment 06c: Full 256³ + gradient checkpointing on single A100 (80GB) # Gradient checkpointing halves activation memory: ~90GB → ~45GB # set -euo pipefail WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction" EXP_DIR="$WORK_DIR/experiments/06_20260331_fullvol_augment" VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer" CONFIG="$EXP_DIR/config_256.yaml" MANIFEST="$WORK_DIR/kwyk_manifest_500.csv" cd "$WORK_DIR" source "${VENV_DIR}/bin/activate" echo "=== Experiment 06c: 256³ + gradient checkpointing (A100) ===" echo "Node: $(hostname)" python -c " import torch print(f'GPU: {torch.cuda.get_device_name(0)}, {torch.cuda.get_device_properties(0).total_memory/1e9:.0f}GB') " # MeshNet 256³ with gradient checkpointing echo "=== Step 1: MeshNet 256³ (gradient checkpointing) ===" python -c " import sys; sys.path.insert(0, '.') from nobrainer.processing.dataset import Dataset from nobrainer.processing.segmentation import Segmentation from utils import load_config, setup_logging import torch log = setup_logging('exp06c') config = load_config('$CONFIG') block_shape = tuple(config['block_shape']) n_classes = config['n_classes'] label_mapping = 
config.get('label_mapping', 'binary') # Load data import csv train_pairs, val_pairs = [], [] with open('$MANIFEST') as f: for row in csv.DictReader(f): p = (row['t1w_path'], row['label_path']) if row['split'] == 'train': train_pairs.append(p) elif row['split'] == 'val': val_pairs.append(p) ds_train = Dataset.from_files(train_pairs, block_shape=block_shape, n_classes=n_classes).batch(1).binarize(label_mapping).streaming(patches_per_volume=1) ds_val = Dataset.from_files(val_pairs[:5], block_shape=block_shape, n_classes=n_classes).batch(1).binarize(label_mapping).streaming(patches_per_volume=1) seg = Segmentation(base_model='meshnet', model_args={'n_classes': n_classes, 'filters': 96, 'receptive_field': 37, 'dropout_rate': 0.25}, checkpoint_filepath='$EXP_DIR/checkpoints/meshnet_256gc') def _log(epoch, logs, model): msg = f'Epoch {epoch+1}/20: loss={logs[\"loss\"]:.4f}' if 'val_loss' in logs: msg += f' val_loss={logs[\"val_loss\"]:.4f} val_acc={logs[\"val_acc\"]:.4f}' log.info(msg) seg.fit(dataset_train=ds_train, dataset_validate=ds_val, epochs=20, optimizer=torch.optim.Adam, opt_args={'lr': 0.0001}, callbacks=[_log], checkpoint_freq=5, gradient_checkpointing=True) log.info('Done. 
Result: %s', seg._training_result) " echo "=== Done: $(date) ===" ================================================ FILE: scripts/kwyk_reproduction/experiments/06_20260331_fullvol_augment/run_256_mp.sbatch ================================================ #!/bin/bash #SBATCH --job-name=exp06-mp #SBATCH --partition=mit_preemptable #SBATCH --gres=gpu:2 #SBATCH --cpus-per-task=8 #SBATCH --mem=64G #SBATCH --time=06:00:00 #SBATCH --output=slurm-256mp-%j.out #SBATCH --error=slurm-256mp-%j.err # # Experiment 06d: Full 256³ + model parallelism on 2× L40S # Layers split across GPUs — each GPU holds ~half the model # set -euo pipefail WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction" EXP_DIR="$WORK_DIR/experiments/06_20260331_fullvol_augment" VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer" CONFIG="$EXP_DIR/config_256_mp.yaml" MANIFEST="$WORK_DIR/kwyk_manifest_500.csv" ZARR_STORE="$WORK_DIR/data/kwyk_500.zarr" cd "$WORK_DIR" source "${VENV_DIR}/bin/activate" echo "=== Experiment 06d: 256³ + model parallel (2× L40S) ===" echo "Node: $(hostname)" python -c " import torch for i in range(torch.cuda.device_count()): print(f'GPU {i}: {torch.cuda.get_device_name(i)}, {torch.cuda.get_device_properties(i).total_memory/1e9:.0f}GB') " # MeshNet 256³ with model parallelism echo "=== Step 1: MeshNet 256³ (model parallel) ===" python 02_train_meshnet.py --manifest "$MANIFEST" --config "$CONFIG" \ --output-dir "$EXP_DIR/checkpoints/meshnet_256mp" --epochs 20 # bwn_multi 256³ with model parallelism echo "=== Step 2: bwn_multi 256³ (model parallel) ===" python 03_train_bayesian.py --manifest "$MANIFEST" --config "$CONFIG" \ --variant bwn_multi --warmstart "$EXP_DIR/checkpoints/meshnet_256mp" \ --output-dir "$EXP_DIR/checkpoints/bwn_multi_256mp" --epochs 20 # Evaluate echo "=== Step 3: Evaluate ===" for v in meshnet_256mp bwn_multi_256mp; do if [ -f "$EXP_DIR/checkpoints/$v/model.pth" ]; then echo "--- $v (deterministic) ---" python 04_evaluate.py --model 
"$EXP_DIR/checkpoints/$v" \ --manifest "$MANIFEST" --config "$CONFIG" \ --split test --n-samples 0 \ --output-dir "$EXP_DIR/results/${v}_det" fi done echo "=== Done: $(date) ===" ================================================ FILE: scripts/kwyk_reproduction/experiments/07_20260401_ddp_128/config.yaml ================================================ # DDP test: 128³ patches on 2× L40S filters: 96 receptive_field: 37 dropout_rate: 0.25 block_shape: [128, 128, 128] lr: 0.0001 batch_size: 4 n_classes: 50 label_mapping: 50-class patches_per_volume: 8 val_dice_freq: 1 zarr_store: /orcd/scratch/orcd/013/satra/kwyk_reproduction/data/kwyk_500.zarr ================================================ FILE: scripts/kwyk_reproduction/experiments/07_20260401_ddp_128/run.sbatch ================================================ #!/bin/bash #SBATCH --job-name=exp07-ddp #SBATCH --partition=mit_preemptable #SBATCH --gres=gpu:2 #SBATCH --cpus-per-task=8 #SBATCH --mem=64G #SBATCH --time=01:00:00 #SBATCH --output=slurm-%j.out #SBATCH --error=slurm-%j.err # # Experiment 07: DDP test on 2× L40S with 128³ patches, 1 epoch # set -euo pipefail WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction" EXP_DIR="$WORK_DIR/experiments/07_20260401_ddp_128" VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer" MANIFEST="$WORK_DIR/kwyk_manifest_500.csv" ZARR_STORE="$WORK_DIR/data/kwyk_500.zarr" cd "$WORK_DIR" source "${VENV_DIR}/bin/activate" echo "=== Experiment 07: DDP 2× L40S, 128³, 1 epoch ===" echo "Node: $(hostname)" python -c " import torch for i in range(torch.cuda.device_count()): print(f'GPU {i}: {torch.cuda.get_device_name(i)}, {torch.cuda.get_device_properties(i).total_memory/1e9:.0f}GB') " python 02_train_meshnet.py \ --manifest "$MANIFEST" \ --config "$EXP_DIR/config.yaml" \ --output-dir "$EXP_DIR/checkpoints/meshnet" \ --epochs 1 echo "=== Done: $(date) ===" ================================================ FILE: 
scripts/kwyk_reproduction/experiments/08_20260401_ddp_128_full/config.yaml ================================================ # DDP 128³ on full dataset (11,479 subjects) filters: 96 receptive_field: 37 dropout_rate: 0.25 block_shape: [128, 128, 128] lr: 0.0001 batch_size: 4 n_classes: 50 label_mapping: 50-class patches_per_volume: 1 val_dice_freq: 1 zarr_store: /orcd/scratch/orcd/013/satra/kwyk_reproduction/data/kwyk_full.zarr ================================================ FILE: scripts/kwyk_reproduction/experiments/08_20260401_ddp_128_full/run.sbatch ================================================ #!/bin/bash #SBATCH --job-name=exp08-full #SBATCH --partition=mit_preemptable #SBATCH --gres=gpu:2 #SBATCH --cpus-per-task=16 #SBATCH --mem=128G #SBATCH --time=24:00:00 #SBATCH --output=slurm-%j.out #SBATCH --error=slurm-%j.err #SBATCH --requeue #SBATCH --signal=USR1@120 # # DDP 128³ on full KWYK dataset (11,479 subjects, 2× L40S) # Checkpoints every epoch for resume after preemption # set -euo pipefail WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction" EXP_DIR="$WORK_DIR/experiments/08_20260401_ddp_128_full" VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer" CONFIG="$EXP_DIR/config.yaml" MANIFEST="$WORK_DIR/kwyk_manifest_full.csv" cd "$WORK_DIR" source "${VENV_DIR}/bin/activate" echo "=== Experiment 08: DDP 128³ full dataset ===" echo "Node: $(hostname)" echo "Job ID: ${SLURM_JOB_ID:-local}" python -c " import torch for i in range(torch.cuda.device_count()): print(f'GPU {i}: {torch.cuda.get_device_name(i)}, {torch.cuda.get_device_properties(i).total_memory/1e9:.0f}GB') " echo "=== Training MeshNet 128³ (20 epochs, full dataset) ===" python 02_train_meshnet.py \ --manifest "$MANIFEST" \ --config "$CONFIG" \ --output-dir "$EXP_DIR/checkpoints/meshnet_128" \ --epochs 20 echo "=== Done: $(date) ===" ================================================ FILE: scripts/kwyk_reproduction/experiments/task-planner.md 
================================================ # Bayesian Model Learning Experiments — Task Planner **Session dates:** 2026-03-30 to 2026-03-31 ## Root Causes Found ### 1. Warm-start key ordering bug (Exp 03) `sorted(state.items())` puts `classifier.weight` before `encoder.X`. **Fixed:** filter classifier, transfer separately. Branch: `fix/warmstart-key-ordering`. ### 2. Single mc= flag (Exp 05 + original TF code review) Original TF bwn trains with `is_mc_v=False` (deterministic VWN) + `is_mc_b=True` (bernoulli dropout ON). PyTorch had one `mc=` flag controlling both. **Fixed:** `mc_vwn` and `mc_dropout` independent flags. Branch: `fix/kwyk-decouple-mc-flags`. ### 3. Dropout ordering mismatch Original TF: conv → dropout → relu. PyTorch had: conv → relu → dropout. **Fixed** in same branch. ### 4. Data pipeline bottleneck NIfTI streaming (20K reads/epoch) too slow. Zarr3 with sharding (1 file per array, 32³ chunk-aligned reads) eliminates this. **Fixed:** sharded Zarr3 conversion + PatchDataset zarr:// support. ### 5. auto_batch_size profiling with wrong mode Profiled with mc=True (both VWN+dropout) but training uses mc_vwn=False. **Fixed:** forward_kwargs parameter in auto_batch_size. 
## Additional Discrepancies Found (from original TF code review) | Aspect | Original TF | Current PyTorch | |---|---|---| | Framework | TensorFlow 1.12 | PyTorch 2.11 | | Dropout rate (bwn) | keep_prob=0.5 (50%) | dropout_rate=0.25 | | Dropout type | tf.nn.dropout (element-wise, no rescale) | nn.Dropout3d (spatial, rescales) | | Loss regularization | L2 weight decay: sum(mu²)/(2*N) | ELBOLoss with KL=0 (just CE) | | sigma_prior (bwn) | 1.0 | 0.1 | | Classifier layer | VWN conv3d | Standard nn.Conv3d | | prior_path support | Yes (load from trained model) | Missing | | Concrete dropout p_prior | 0.5 | 0.9 | | Subjects | ~10,000 | 500 (current) | ## Experiment Log | # | Name | Status | Key Finding | |---|---|---|---| | 01 | Eval det mode | DONE | Zero Dice both modes — weights empty | | 02 | Binary Bayesian | DONE | Zero Dice — same issue | | 03 | Warm-start diagnostic | DONE | **BUG: sorted key ordering** | | 04 | Fixed warm-start | DONE | Transfer improved but training destroys signal | | 05 | From scratch | DONE | **mc=False trains, mc=True doesn't** | | 06 | Original TF code review | DONE | **is_mc_v=False, is_mc_b=True in original** | ## Current Pipeline (running) - Job 11222646: Zarr3 conversion (500 subjects, sharded) - Jobs 11222647-49: Bayesian training (mc_vwn=False, mc_dropout=True, streaming Zarr) - Job 11222650: Evaluation (det + MC) ## Next Steps 1. Match dropout rate to original (0.5 not 0.25) 2. Add L2 weight decay to match original loss (not KL) 3. Scale to more subjects / epochs 4. 
Implement prior_path for multi-stage training ================================================ FILE: scripts/kwyk_reproduction/label_mappings/115-class-mapping.csv ================================================ original,new,label 0,0,Unknown 2,1,Left-Cerebral-White-Matter 3,2,Left-Cerebral-Cortex 4,3,Left-Lateral-Ventricle 5,4,Left-Inf-Lat-Vent 7,5,Left-Cerebellum-White-Matter 8,6,Left-Cerebellum-Cortex 10,7,Left-Thalamus-Proper 11,8,Left-Caudate 12,9,Left-Putamen 13,10,Left-Pallidum 14,11,3rd-Ventricle 15,12,4th-Ventricle 16,13,Brain-Stem 17,14,Left-Hippocampus 18,15,Left-Amygdala 24,16,CSF 26,17,Left-Accumbens-area 28,18,Left-VentralDC 30,19,Left-vessel 31,20,Left-choroid-plexus 41,21,Right-Cerebral-White-Matter 42,22,Right-Cerebral-Cortex 43,23,Right-Lateral-Ventricle 44,24,Right-Inf-Lat-Vent 46,25,Right-Cerebellum-White-Matter 47,26,Right-Cerebellum-Cortex 49,27,Right-Thalamus-Proper 50,28,Right-Caudate 51,29,Right-Putamen 52,30,Right-Pallidum 53,31,Right-Hippocampus 54,32,Right-Amygdala 58,33,Right-Accumbens-area 60,34,Right-VentralDC 62,35,Right-vessel 63,36,Right-choroid-plexus 72,37,5th-Ventricle 77,38,WM-hypointensities 85,39,Optic-Chiasm 251,40,CC_Posterior 252,41,CC_Mid_Posterior 253,42,CC_Central 254,43,CC_Mid_Anterior 255,44,CC_Anterior 1000,45,ctx-lh-unknown 1001,46,ctx-lh-bankssts 1002,47,ctx-lh-caudalanteriorcingulate 1003,48,ctx-lh-caudalmiddlefrontal 1005,49,ctx-lh-cuneus 1006,50,ctx-lh-entorhinal 1007,51,ctx-lh-fusiform 1008,52,ctx-lh-inferiorparietal 1009,53,ctx-lh-inferiortemporal 1010,54,ctx-lh-isthmuscingulate 1011,55,ctx-lh-lateraloccipital 1012,56,ctx-lh-lateralorbitofrontal 1013,57,ctx-lh-lingual 1014,58,ctx-lh-medialorbitofrontal 1015,59,ctx-lh-middletemporal 1016,60,ctx-lh-parahippocampal 1017,61,ctx-lh-paracentral 1018,62,ctx-lh-parsopercularis 1019,63,ctx-lh-parsorbitalis 1020,64,ctx-lh-parstriangularis 1021,65,ctx-lh-pericalcarine 1022,66,ctx-lh-postcentral 1023,67,ctx-lh-posteriorcingulate 1024,68,ctx-lh-precentral 
1025,69,ctx-lh-precuneus 1026,70,ctx-lh-rostralanteriorcingulate 1027,71,ctx-lh-rostralmiddlefrontal 1028,72,ctx-lh-superiorfrontal 1029,73,ctx-lh-superiorparietal 1030,74,ctx-lh-superiortemporal 1031,75,ctx-lh-supramarginal 1032,76,ctx-lh-frontalpole 1033,77,ctx-lh-temporalpole 1034,78,ctx-lh-transversetemporal 1035,79,ctx-lh-insula 2000,80,ctx-rh-unknown 2001,81,ctx-rh-bankssts 2002,82,ctx-rh-caudalanteriorcingulate 2003,83,ctx-rh-caudalmiddlefrontal 2005,84,ctx-rh-cuneus 2006,85,ctx-rh-entorhinal 2007,86,ctx-rh-fusiform 2008,87,ctx-rh-inferiorparietal 2009,88,ctx-rh-inferiortemporal 2010,89,ctx-rh-isthmuscingulate 2011,90,ctx-rh-lateraloccipital 2012,91,ctx-rh-lateralorbitofrontal 2013,92,ctx-rh-lingual 2014,93,ctx-rh-medialorbitofrontal 2015,94,ctx-rh-middletemporal 2016,95,ctx-rh-parahippocampal 2017,96,ctx-rh-paracentral 2018,97,ctx-rh-parsopercularis 2019,98,ctx-rh-parsorbitalis 2020,99,ctx-rh-parstriangularis 2021,100,ctx-rh-pericalcarine 2022,101,ctx-rh-postcentral 2023,102,ctx-rh-posteriorcingulate 2024,103,ctx-rh-precentral 2025,104,ctx-rh-precuneus 2026,105,ctx-rh-rostralanteriorcingulate 2027,106,ctx-rh-rostralmiddlefrontal 2028,107,ctx-rh-superiorfrontal 2029,108,ctx-rh-superiorparietal 2030,109,ctx-rh-superiortemporal 2031,110,ctx-rh-supramarginal 2032,111,ctx-rh-frontalpole 2033,112,ctx-rh-temporalpole 2034,113,ctx-rh-transversetemporal 2035,114,ctx-rh-insula ================================================ FILE: scripts/kwyk_reproduction/label_mappings/50-class-mapping.csv ================================================ ,original,new,label 0,0,0,Unknown 1,2,1,Left-Cerebral-White-Matter 2,4,2,Left-Lateral-Ventricle 3,5,2,Left-Inf-Lat-Vent 4,7,3,Left-Cerebellum-White-Matter 5,8,4,Left-Cerebellum-Cortex 6,10,5,Left-Thalamus-Proper 7,11,6,Left-Caudate 8,12,7,Left-Putamen 9,13,8,Left-Pallidum 10,14,2,3rd-Ventricle 11,15,2,4th-Ventricle 12,16,9,Brain-Stem 13,17,10,Left-Hippocampus 14,18,11,Left-Amygdala 15,24,12,CSF 16,26,13,Left-Accumbens-area 
17,28,14,Left-VentralDC 18,41,1,Right-Cerebral-White-Matter 19,43,2,Right-Lateral-Ventricle 20,44,2,Right-Inf-Lat-Vent 21,46,3,Right-Cerebellum-White-Matter 22,47,4,Right-Cerebellum-Cortex 23,49,5,Right-Thalamus-Proper 24,50,6,Right-Caudate 25,51,7,Right-Putamen 26,52,8,Right-Pallidum 27,53,10,Right-Hippocampus 28,54,11,Right-Amygdala 29,58,13,Right-Accumbens-area 30,60,14,Right-VentralDC 31,72,2,5th-Ventricle 32,192,15,Corpus_Callosum 33,251,15,CC_Posterior 34,252,15,CC_Mid_Posterior 35,253,15,CC_Central 36,254,15,CC_Mid_Anterior 37,255,15,CC_Anterior 38,1001,16,ctx-lh-bankssts 39,1002,17,ctx-lh-caudalanteriorcingulate 40,1003,18,ctx-lh-caudalmiddlefrontal 41,1005,19,ctx-lh-cuneus 42,1006,20,ctx-lh-entorhinal 43,1007,21,ctx-lh-fusiform 44,1008,22,ctx-lh-inferiorparietal 45,1009,23,ctx-lh-inferiortemporal 46,1010,24,ctx-lh-isthmuscingulate 47,1011,25,ctx-lh-lateraloccipital 48,1012,26,ctx-lh-lateralorbitofrontal 49,1013,27,ctx-lh-lingual 50,1014,28,ctx-lh-medialorbitofrontal 51,1015,29,ctx-lh-middletemporal 52,1016,30,ctx-lh-parahippocampal 53,1017,31,ctx-lh-paracentral 54,1018,32,ctx-lh-parsopercularis 55,1019,33,ctx-lh-parsorbitalis 56,1020,34,ctx-lh-parstriangularis 57,1021,35,ctx-lh-pericalcarine 58,1022,36,ctx-lh-postcentral 59,1023,37,ctx-lh-posteriorcingulate 60,1024,38,ctx-lh-precentral 61,1025,39,ctx-lh-precuneus 62,1026,40,ctx-lh-rostralanteriorcingulate 63,1027,41,ctx-lh-rostralmiddlefrontal 64,1028,42,ctx-lh-superiorfrontal 65,1029,43,ctx-lh-superiorparietal 66,1030,44,ctx-lh-superiortemporal 67,1031,45,ctx-lh-supramarginal 68,1032,46,ctx-lh-frontalpole 69,1033,47,ctx-lh-temporalpole 70,1034,48,ctx-lh-transversetemporal 71,1035,49,ctx-lh-insula 72,2001,16,ctx-rh-bankssts 73,2002,17,ctx-rh-caudalanteriorcingulate 74,2003,18,ctx-rh-caudalmiddlefrontal 75,2005,19,ctx-rh-cuneus 76,2006,20,ctx-rh-entorhinal 77,2007,21,ctx-rh-fusiform 78,2008,22,ctx-rh-inferiorparietal 79,2009,23,ctx-rh-inferiortemporal 80,2010,24,ctx-rh-isthmuscingulate 
81,2011,25,ctx-rh-lateraloccipital 82,2012,26,ctx-rh-lateralorbitofrontal 83,2013,27,ctx-rh-lingual 84,2014,28,ctx-rh-medialorbitofrontal 85,2015,29,ctx-rh-middletemporal 86,2016,30,ctx-rh-parahippocampal 87,2017,31,ctx-rh-paracentral 88,2018,32,ctx-rh-parsopercularis 89,2019,33,ctx-rh-parsorbitalis 90,2020,34,ctx-rh-parstriangularis 91,2021,35,ctx-rh-pericalcarine 92,2022,36,ctx-rh-postcentral 93,2023,37,ctx-rh-posteriorcingulate 94,2024,38,ctx-rh-precentral 95,2025,39,ctx-rh-precuneus 96,2026,40,ctx-rh-rostralanteriorcingulate 97,2027,41,ctx-rh-rostralmiddlefrontal 98,2028,42,ctx-rh-superiorfrontal 99,2029,43,ctx-rh-superiorparietal 100,2030,44,ctx-rh-superiortemporal 101,2031,45,ctx-rh-supramarginal 102,2032,46,ctx-rh-frontalpole 103,2033,47,ctx-rh-temporalpole 104,2034,48,ctx-rh-transversetemporal 105,2035,49,ctx-rh-insul ================================================ FILE: scripts/kwyk_reproduction/label_mappings/6-class-mapping.csv ================================================ ,original,new,label,50-class 0,0,0,Unknown,0 1,2,1,Left-Cerebral-White-Matter,1 2,4,3,Left-Lateral-Ventricle,2 3,5,3,Left-Inf-Lat-Vent,2 4,7,1,Left-Cerebellum-White-Matter,3 5,8,2,Left-Cerebellum-Cortex,4 6,10,4,Left-Thalamus-Proper,5 7,11,4,Left-Caudate,6 8,12,4,Left-Putamen,7 9,13,4,Left-Pallidum,8 10,14,3,3rd-Ventricle,2 11,15,3,4th-Ventricle,2 12,16,5,Brain-Stem,9 13,17,4,Left-Hippocampus,10 14,18,4,Left-Amygdala,11 15,24,3,CSF,12 16,26,4,Left-Accumbens-area,13 17,28,4,Left-VentralDC,14 18,41,1,Right-Cerebral-White-Matter,1 19,43,3,Right-Lateral-Ventricle,2 20,44,3,Right-Inf-Lat-Vent,2 21,46,1,Right-Cerebellum-White-Matter,3 22,47,2,Right-Cerebellum-Cortex,4 23,49,4,Right-Thalamus-Proper,5 24,50,4,Right-Caudate,6 25,51,4,Right-Putamen,7 26,52,4,Right-Pallidum,8 27,53,4,Right-Hippocampus,10 28,54,4,Right-Amygdala,11 29,58,4,Right-Accumbens-area,13 30,60,4,Right-VentralDC,14 31,72,3,5th-Ventricle,2 32,192,1,Corpus_Callosum,15 33,251,1,CC_Posterior,15 34,252,1,CC_Mid_Posterior,15 
35,253,1,CC_Central,15 36,254,1,CC_Mid_Anterior,15 37,255,1,CC_Anterior,15 38,1001,2,ctx-lh-bankssts,16 39,1002,2,ctx-lh-caudalanteriorcingulate,17 40,1003,2,ctx-lh-caudalmiddlefrontal,18 41,1005,2,ctx-lh-cuneus,19 42,1006,2,ctx-lh-entorhinal,20 43,1007,2,ctx-lh-fusiform,21 44,1008,2,ctx-lh-inferiorparietal,22 45,1009,2,ctx-lh-inferiortemporal,23 46,1010,2,ctx-lh-isthmuscingulate,24 47,1011,2,ctx-lh-lateraloccipital,25 48,1012,2,ctx-lh-lateralorbitofrontal,26 49,1013,2,ctx-lh-lingual,27 50,1014,2,ctx-lh-medialorbitofrontal,28 51,1015,2,ctx-lh-middletemporal,29 52,1016,2,ctx-lh-parahippocampal,30 53,1017,2,ctx-lh-paracentral,31 54,1018,2,ctx-lh-parsopercularis,32 55,1019,2,ctx-lh-parsorbitalis,33 56,1020,2,ctx-lh-parstriangularis,34 57,1021,2,ctx-lh-pericalcarine,35 58,1022,2,ctx-lh-postcentral,36 59,1023,2,ctx-lh-posteriorcingulate,37 60,1024,2,ctx-lh-precentral,38 61,1025,2,ctx-lh-precuneus,39 62,1026,2,ctx-lh-rostralanteriorcingulate,40 63,1027,2,ctx-lh-rostralmiddlefrontal,41 64,1028,2,ctx-lh-superiorfrontal,42 65,1029,2,ctx-lh-superiorparietal,43 66,1030,2,ctx-lh-superiortemporal,44 67,1031,2,ctx-lh-supramarginal,45 68,1032,2,ctx-lh-frontalpole,46 69,1033,2,ctx-lh-temporalpole,47 70,1034,2,ctx-lh-transversetemporal,48 71,1035,2,ctx-lh-insula,49 72,2001,2,ctx-rh-bankssts,16 73,2002,2,ctx-rh-caudalanteriorcingulate,17 74,2003,2,ctx-rh-caudalmiddlefrontal,18 75,2005,2,ctx-rh-cuneus,19 76,2006,2,ctx-rh-entorhinal,20 77,2007,2,ctx-rh-fusiform,21 78,2008,2,ctx-rh-inferiorparietal,22 79,2009,2,ctx-rh-inferiortemporal,23 80,2010,2,ctx-rh-isthmuscingulate,24 81,2011,2,ctx-rh-lateraloccipital,25 82,2012,2,ctx-rh-lateralorbitofrontal,26 83,2013,2,ctx-rh-lingual,27 84,2014,2,ctx-rh-medialorbitofrontal,28 85,2015,2,ctx-rh-middletemporal,29 86,2016,2,ctx-rh-parahippocampal,30 87,2017,2,ctx-rh-paracentral,31 88,2018,2,ctx-rh-parsopercularis,32 89,2019,2,ctx-rh-parsorbitalis,33 90,2020,2,ctx-rh-parstriangularis,34 91,2021,2,ctx-rh-pericalcarine,35 
92,2022,2,ctx-rh-postcentral,36 93,2023,2,ctx-rh-posteriorcingulate,37 94,2024,2,ctx-rh-precentral,38 95,2025,2,ctx-rh-precuneus,39 96,2026,2,ctx-rh-rostralanteriorcingulate,40 97,2027,2,ctx-rh-rostralmiddlefrontal,41 98,2028,2,ctx-rh-superiorfrontal,42 99,2029,2,ctx-rh-superiorparietal,43 100,2030,2,ctx-rh-superiortemporal,44 101,2031,2,ctx-rh-supramarginal,45 102,2032,2,ctx-rh-frontalpole,46 103,2033,2,ctx-rh-temporalpole,47 104,2034,2,ctx-rh-transversetemporal,48 105,2035,2,ctx-rh-insul,49
================================================
FILE: scripts/kwyk_reproduction/run.sh
================================================
#!/bin/bash
# KWYK Brain Extraction Reproduction — Full Pipeline Runner
#
# Usage:
#   ./run.sh                    # Full pipeline (data + train + evaluate)
#   ./run.sh --smoke-test       # Quick smoke test (5 volumes, 2 epochs)
#   ./run.sh --step data        # Run only data assembly
#   ./run.sh --step train       # Run only training (deterministic + Bayesian)
#   ./run.sh --step evaluate    # Run only evaluation
#   ./run.sh --step compare     # Run only kwyk comparison
#   ./run.sh --step sweep       # Run only block size sweep
#
# Environment:
#   Creates a dedicated venv at .venv-kwyk/ with all dependencies.
#   Set NOBRAINER_ROOT to override the nobrainer repo location.

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
NOBRAINER_ROOT="${NOBRAINER_ROOT:-$(cd "$SCRIPT_DIR/../.." && pwd)}"
VENV_DIR="$SCRIPT_DIR/.venv-kwyk"
# Default step. Only --step/--all in the parser below may change it; flags
# such as --smoke-test must not be mistaken for a step name.
STEP="--all"

# Colors for output
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m'

# '%b' expands the color escape sequences; '%s' prints the message verbatim.
log() { printf '%b[kwyk]%b %s\n' "$GREEN" "$NC" "$*"; }
warn() { printf '%b[kwyk]%b %s\n' "$YELLOW" "$NC" "$*"; }
err() { printf '%b[kwyk]%b %s\n' "$RED" "$NC" "$*" >&2; }

# --- Setup venv ---
# Create the dedicated venv if missing, activate it, and install nobrainer
# (editable, with extras) plus the pipeline's extra dependencies.
setup_venv() {
    if [ ! -d "$VENV_DIR" ]; then
        log "Creating virtual environment at $VENV_DIR"
        uv venv --python 3.14 "$VENV_DIR"
    fi
    log "Installing dependencies..."
    # shellcheck disable=SC1091
    source "$VENV_DIR/bin/activate"
    uv pip install -e "$NOBRAINER_ROOT[bayesian,zarr,versioning,dev]" \
        monai pyro-ppl datalad matplotlib pyyaml scipy nibabel 2>&1 | tail -3
    log "Dependencies installed"
}

# --- Parse arguments ---
SMOKE_TEST=false
while [[ $# -gt 0 ]]; do
    case "$1" in
        --smoke-test)
            SMOKE_TEST=true
            shift
            ;;
        --step)
            # Explicit check: under `set -u` a trailing --step would otherwise
            # abort with an opaque "$2: unbound variable".
            if [[ $# -lt 2 ]]; then
                err "--step requires a value"
                exit 1
            fi
            STEP="$2"
            shift 2
            ;;
        --all)
            STEP="--all"
            shift
            ;;
        *)
            err "Unknown argument: $1"
            exit 1
            ;;
    esac
done

# Reject bad step names before the slow venv setup, not after it.
case "$STEP" in
    --all | data | train | evaluate | compare | sweep) ;;
    *)
        err "Unknown step: $STEP"
        err "Available: data, train, evaluate, compare, sweep"
        exit 1
        ;;
esac

setup_venv
# Every relative path below (scripts, manifest.csv, checkpoints/, results/)
# is resolved against this directory.
cd "$SCRIPT_DIR"

# --- Smoke test configuration ---
# EXTRA_ARGS and DATASETS are space-separated lists that are deliberately
# expanded unquoted below so each flag / dataset id becomes its own word.
if [ "$SMOKE_TEST" = true ]; then
    log "Running SMOKE TEST (5 volumes, 2 epochs, tiny model)"
    EXTRA_ARGS="--epochs 2"
    DATASETS="ds000114"  # Use get_data() instead of DataLad for smoke test
else
    EXTRA_ARGS=""
    DATASETS="ds000114 ds000228 ds002609"
fi

# --- Step: Data Assembly ---
# Build manifest.csv and data/ from the OpenNeuro datasets in $DATASETS.
run_data() {
    log "Step 1: Assembling dataset from OpenNeuro..."
    # shellcheck disable=SC2086
    python 01_assemble_dataset.py \
        --datasets $DATASETS \
        --output-csv manifest.csv \
        --output-dir data \
        --label-mapping binary
    # Skip the CSV header so the count is subjects rather than lines.
    log "Dataset assembled: $(tail -n +2 manifest.csv | wc -l) subjects"
}

# Train one Bayesian variant warm-started from checkpoints/meshnet.
# Arguments: $1 - variant name (also the checkpoint sub-directory name)
_train_bayesian_variant() {
    local variant=$1
    # shellcheck disable=SC2086
    python 03_train_bayesian.py \
        --manifest manifest.csv \
        --config config.yaml \
        --variant "$variant" \
        --warmstart checkpoints/meshnet \
        --output-dir "checkpoints/$variant" \
        $EXTRA_ARGS
}

# --- Step: Training ---
# Deterministic MeshNet first; every Bayesian variant warm-starts from it.
run_train() {
    log "Step 2: Training deterministic MeshNet (warm-start foundation)..."
    # shellcheck disable=SC2086
    python 02_train_meshnet.py \
        --manifest manifest.csv \
        --config config.yaml \
        --output-dir checkpoints/meshnet \
        $EXTRA_ARGS
    log "Deterministic MeshNet trained (bwn / MAP variant)"

    log "Step 3a: MC Bernoulli dropout variant (bwn_multi)..."
    _train_bayesian_variant bwn_multi
    log "MC Bernoulli dropout variant saved"

    log "Step 3b: Spike-and-slab dropout variant (bvwn_multi_prior)..."
    _train_bayesian_variant bvwn_multi_prior
    log "Spike-and-slab dropout variant trained"

    log "Step 3c: Standard Gaussian Bayesian variant (for comparison)..."
    _train_bayesian_variant bayesian_gaussian
    log "Gaussian Bayesian variant trained"
}

# --- Step: Evaluate ---
# Evaluate every trained variant (those with a model.pth) on the test split.
run_evaluate() {
    log "Step 4: Evaluating all model variants on test set..."
    if [ ! -f 04_evaluate.py ]; then
        warn "04_evaluate.py not found"
        return 0
    fi
    local variant_dir variant_name
    for variant_dir in checkpoints/meshnet checkpoints/bwn_multi checkpoints/bvwn_multi_prior checkpoints/bayesian_gaussian; do
        variant_name=$(basename "$variant_dir")
        if [ -f "$variant_dir/model.pth" ]; then
            log " Evaluating $variant_name..."
            # Same calling convention as the SLURM evaluation scripts:
            # pass the checkpoint directory plus an explicit --config.
            python 04_evaluate.py \
                --model "$variant_dir" \
                --manifest manifest.csv \
                --config config.yaml \
                --split test \
                --n-samples 10 \
                --output-dir "results/$variant_name"
        else
            warn " Skipping $variant_name (no model.pth found)"
        fi
    done
}

# --- Step: Compare ---
# Compare a trained model against the original kwyk container, if both the
# comparison script and the model exist (paths are relative to SCRIPT_DIR).
run_compare() {
    log "Step 5: Comparing with original kwyk container..."
    # NOTE(review): no step in run_train writes checkpoints/bayesian/ —
    # confirm which variant should be compared.
    local new_model=checkpoints/bayesian/model.pth
    if [ ! -f 05_compare_kwyk.py ]; then
        warn "05_compare_kwyk.py not yet implemented"
    elif [ ! -f "$new_model" ]; then
        warn " Skipping comparison ($new_model not found)"
    else
        python 05_compare_kwyk.py \
            --new-model "$new_model" \
            --kwyk-dir "$NOBRAINER_ROOT/../kwyk" \
            --manifest manifest.csv \
            --output-dir results/comparison
    fi
}

# --- Step: Block Size Sweep ---
# Run the block-size sweep if the script exists (relative to SCRIPT_DIR).
run_sweep() {
    log "Step 6: Block size sweep..."
    if [ -f 06_block_size_sweep.py ]; then
        python 06_block_size_sweep.py \
            --manifest manifest.csv \
            --block-sizes 32 64 128 \
            --output-dir results/sweep
    else
        warn "06_block_size_sweep.py not yet implemented"
    fi
}

# --- Execute ---
# STEP was validated above, so every value reaching here has a branch.
case "$STEP" in
    --all)
        run_data
        run_train
        run_evaluate
        run_compare
        run_sweep
        ;;
    data) run_data ;;
    train) run_train ;;
    evaluate) run_evaluate ;;
    compare) run_compare ;;
    sweep) run_sweep ;;
esac

log "Done! Check figures/ and results/ for outputs."
================================================
FILE: scripts/kwyk_reproduction/slurm_convert_zarr.sbatch
================================================
#!/bin/bash
#SBATCH --job-name=kwyk-zarr
#SBATCH --partition=mit_preemptable
#SBATCH --cpus-per-task=8
#SBATCH --mem=64G
#SBATCH --time=02:00:00
#SBATCH --output=slurm-zarr-%j.out
#SBATCH --error=slurm-zarr-%j.err
#
# Convert PAC NIfTI dataset to Zarr3 for fast chunk-aligned I/O
#
# Writes $ZARR_OUT plus ${ZARR_OUT}_partition.json listing subject ids per
# manifest split. No scratch files are written to the node-shared /tmp.
#
set -euo pipefail

WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction"
VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer"
MANIFEST="kwyk_manifest_500.csv"
ZARR_OUT="$WORK_DIR/data/kwyk_500.zarr"

cd "$WORK_DIR"
source "${VENV_DIR}/bin/activate"

echo "=== Converting PAC dataset to Zarr3 ==="
echo "Started: $(date)"

# Convert using nobrainer API directly
python -c "
import csv
import json

from nobrainer.datasets.zarr_store import create_zarr_store

# Read the manifest once and reuse the rows for conversion and partitioning.
with open('${MANIFEST}') as f:
    rows = list(csv.DictReader(f))

pairs = [(row['t1w_path'], row['label_path']) for row in rows]
subject_ids = [row['subject_id'] for row in rows]

print(f'Converting {len(pairs)} volumes to Zarr3...')
store = create_zarr_store(
    image_label_pairs=pairs,
    output_path='${ZARR_OUT}',
    subject_ids=subject_ids,
    chunk_shape=(32, 32, 32),
    conform=True,
    target_shape=(256, 256, 256),
    target_voxel_size=(1.0, 1.0, 1.0),
)
print(f'Zarr store created: {store}')

# Create partition JSON for train/val/test splits. Any other split label gets
# its own key instead of raising KeyError after the (slow) conversion above.
partitions = {'train': [], 'val': [], 'test': []}
for row in rows:
    partitions.setdefault(row['split'], []).append(row['subject_id'])

part_path = '${ZARR_OUT}_partition.json'
with open(part_path, 'w') as f:
    json.dump({'partitions': partitions}, f, indent=2)
print(f'Partition file: {part_path}')
for k, v in partitions.items():
    print(f' {k}: {len(v)} subjects')
"

echo "=== Done: $(date) ==="
================================================
FILE: scripts/kwyk_reproduction/slurm_kwyk_bayesian.sbatch
================================================
#!/bin/bash
#SBATCH --job-name=kwyk-bayes
#SBATCH --partition=mit_preemptable
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=8
#SBATCH --mem=32G
#SBATCH --time=06:00:00
#SBATCH --output=slurm-kwyk-bayes-%j.out
#SBATCH --error=slurm-kwyk-bayes-%j.err
#
# KWYK Smoke Test — Train one Bayesian variant (launched in parallel)
#
# Usage (via submit_kwyk_smoke.sh, not directly):
#   sbatch --dependency=afterok:$MESHNET_JOB slurm_kwyk_bayesian.sbatch bwn_multi
#
set -euo pipefail

# The variant name is required; ${1:?} aborts with the usage text otherwise.
VARIANT="${1:?Usage: sbatch slurm_kwyk_bayesian.sbatch <variant>}"
WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction"
VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer"
CONFIG="config_kwyk_smoke.yaml"
MANIFEST="kwyk_manifest_500.csv"

echo "=== KWYK Bayesian Training: ${VARIANT} ==="
echo "Job ID: ${SLURM_JOB_ID:-local}"
echo "Node: $(hostname)"
echo "Started: $(date)"

cd "$WORK_DIR"
source "${VENV_DIR}/bin/activate"

EPOCHS="${KWYK_EPOCHS:-20}"

echo "=== Training ${VARIANT} (${EPOCHS} epochs) ==="
python 03_train_bayesian.py --manifest "$MANIFEST" --config "$CONFIG" \
    --variant "$VARIANT" --warmstart checkpoints/kwyk_smoke_meshnet \
    --output-dir "checkpoints/kwyk_smoke_${VARIANT}" --epochs "$EPOCHS"

echo "=== ${VARIANT} complete: $(date) ==="
================================================
FILE: scripts/kwyk_reproduction/slurm_kwyk_evaluate.sbatch
================================================
#!/bin/bash
#SBATCH --job-name=kwyk-eval
#SBATCH --partition=mit_preemptable
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=8
#SBATCH --mem=32G
#SBATCH --time=04:00:00
#SBATCH --output=slurm-kwyk-eval-%j.out
#SBATCH --error=slurm-kwyk-eval-%j.err
#
# KWYK Smoke Test — Evaluate all variants (runs after all training completes)
#
set -euo pipefail

WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction"
VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer"
CONFIG="config_kwyk_smoke.yaml"
MANIFEST="kwyk_manifest_500.csv"

echo "=== KWYK Evaluation ==="
echo "Job ID: ${SLURM_JOB_ID:-local}"
echo "Node: $(hostname)"
echo "Started: $(date)"

cd "$WORK_DIR"
source "${VENV_DIR}/bin/activate"

# Each trained variant gets a deterministic pass and a 3-sample MC pass.
for v in kwyk_smoke_meshnet kwyk_smoke_bwn_multi kwyk_smoke_bvwn_multi_prior kwyk_smoke_bayesian_gaussian; do
    if [ -f "checkpoints/$v/model.pth" ]; then
        echo "--- Evaluating $v (deterministic) ---"
        python 04_evaluate.py --model "checkpoints/$v" \
            --manifest "$MANIFEST" --config "$CONFIG" \
            --split test --n-samples 0 \
            --output-dir "results/${v}_det"

        echo "--- Evaluating $v (MC, 3 samples) ---"
        python 04_evaluate.py --model "checkpoints/$v" \
            --manifest "$MANIFEST" --config "$CONFIG" \
            --split test --n-samples 3 \
            --output-dir "results/${v}_mc"
    else
        echo "WARN: checkpoints/$v/model.pth not found, skipping"
    fi
done

echo "=== Evaluation complete: $(date) ==="
echo "Results:"
# Use if/fi rather than `[ -f ] && echo`: when the glob matches nothing, the
# failed test would otherwise become the script's exit status and mark the
# SLURM job as FAILED.
for csv in results/kwyk_smoke_*/dice_scores.csv; do
    if [ -f "$csv" ]; then
        echo " $csv: $(tail -n +2 "$csv" | wc -l) volumes"
    fi
done
================================================
FILE: scripts/kwyk_reproduction/slurm_kwyk_smoke.sbatch
================================================
#!/bin/bash
#SBATCH --job-name=kwyk-pac-smoke
#SBATCH --partition=mit_preemptable
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=8
#SBATCH --mem=32G
#SBATCH --time=06:00:00
#SBATCH --output=slurm-kwyk-smoke-%j.out
#SBATCH --error=slurm-kwyk-smoke-%j.err
#
# KWYK Training — Step 1: deterministic MeshNet + manifest build
# Steps 2-4 (Bayesian variants) are launched as dependent parallel jobs
# by submit_kwyk_smoke.sh
# Set KWYK_EPOCHS env var to override (default: 20)
#
set -euo pipefail

WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction"
VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer"
DATA_DIR="/orcd/scratch/orcd/013/satra/data/SharedData/segmentation/freesurfer_asegs"
CONFIG="config_kwyk_smoke.yaml"
MANIFEST="kwyk_manifest_500.csv"
N_SUBJECTS=500

echo "=== KWYK PAC Dataset Smoke Test — MeshNet ==="
echo "Job ID: ${SLURM_JOB_ID:-local}"
echo "Node: $(hostname)"
echo "Started: $(date)"

cd "$WORK_DIR"
source "${VENV_DIR}/bin/activate"

# Report the PyTorch / CUDA environment this job landed on.
python -c "
import torch
print('PyTorch:', torch.__version__)
print('CUDA:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))
    print('Memory:', round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1), 'GB')
"

# --- Build manifest if needed ---
if [ ! -f "$MANIFEST" ]; then
    echo "=== Building manifest (${N_SUBJECTS} subjects) ==="
    python build_kwyk_manifest.py \
        --data-dir "$DATA_DIR" \
        --output-csv "$MANIFEST" \
        --n-subjects "$N_SUBJECTS" \
        --seed 42
fi

echo "Split counts:"
# Find the "split" column by its header name instead of assuming it is the
# 5th field, so a manifest with a different column order still reports the
# right counts (prints nothing if there is no such column).
awk -F, '
    NR == 1 { for (i = 1; i <= NF; i++) if ($i == "split") col = i; next }
    col     { print $col }
' "$MANIFEST" | sort | uniq -c

EPOCHS="${KWYK_EPOCHS:-20}"

# --- Train deterministic MeshNet ---
echo "=== Training deterministic MeshNet (${EPOCHS} epochs) ==="
python 02_train_meshnet.py --manifest "$MANIFEST" --config "$CONFIG" \
    --output-dir checkpoints/kwyk_smoke_meshnet --epochs "$EPOCHS"

echo "=== MeshNet complete: $(date) ==="

================================================
FILE: scripts/kwyk_reproduction/slurm_train.sbatch
================================================
#!/bin/bash
#SBATCH --job-name=kwyk-train
#SBATCH --partition=preemptible
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=4
#SBATCH --mem=32G
#SBATCH --time=24:00:00
#SBATCH --requeue
#SBATCH --signal=B:USR1@120
#SBATCH --output=slurm-%j.out
#SBATCH --error=slurm-%j.err
#
# KWYK Brain Extraction Reproduction — SLURM Preemptible Training
#
# Trains all kwyk model variants with automatic checkpoint/resume on
# preemption and automatic batch size optimization per GPU.
#
# Usage:
#   sbatch slurm_train.sbatch                          # 1 GPU
#   sbatch --gres=gpu:4 slurm_train.sbatch             # 4 GPUs
#   sbatch --partition=gpu slurm_train.sbatch          # Non-preemptible
#   KWYK_EPOCHS=100 sbatch slurm_train.sbatch          # More epochs
#
# Multi-GPU:
#   Batch size is auto-optimized per GPU via nobrainer.gpu.auto_batch_size.
#   Request more GPUs with --gres=gpu:N.
#
# Environment variables:
#   KWYK_DATASETS    — space-separated OpenNeuro IDs (default: ds000114)
#   KWYK_EPOCHS      — epochs per variant (default: 50)
#   KWYK_WORK_DIR    — working directory (default: $SLURM_SUBMIT_DIR)
#   KWYK_VENV        — path to venv (default: .venv-kwyk)
#   KWYK_SCRIPT_DIR  — directory containing this script and the 0*_*.py steps
#                      (default: $SLURM_SUBMIT_DIR under SLURM, otherwise the
#                      directory of this file)

set -euo pipefail

WORK_DIR="${KWYK_WORK_DIR:-${SLURM_SUBMIT_DIR:-$(pwd)}}"
VENV_DIR="${KWYK_VENV:-${WORK_DIR}/.venv-kwyk}"
DATASETS="${KWYK_DATASETS:-ds000114}"
EPOCHS="${KWYK_EPOCHS:-50}"

# sbatch executes a copy of this file from slurmd's spool directory, so
# BASH_SOURCE does not point into scripts/kwyk_reproduction under SLURM.
# There, default to the submit directory (the usage above assumes the job is
# submitted from this script's directory).
if [ -n "${SLURM_JOB_ID:-}" ]; then
    SCRIPT_DIR="${KWYK_SCRIPT_DIR:-${SLURM_SUBMIT_DIR:-$(pwd)}}"
else
    SCRIPT_DIR="${KWYK_SCRIPT_DIR:-$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)}"
fi
SCRIPT_DIR="$(cd "$SCRIPT_DIR" && pwd)"
NOBRAINER_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"

echo "=== KWYK SLURM Training ==="
echo "Job ID: ${SLURM_JOB_ID:-local}"
echo "Node: $(hostname)"
echo "Partition: ${SLURM_JOB_PARTITION:-unknown}"
echo "GPUs: ${SLURM_GPUS_ON_NODE:-unknown}"
echo "Restart: ${SLURM_RESTART_COUNT:-0}"

# --- Preemption handling ---
# "--signal=B:USR1@120" delivers SIGUSR1 to this batch shell only. Without a
# trap, SIGUSR1's default action terminates the shell and the Python step
# never sees it. Steps are started through run_step (in the background) so
# the trap can run while they execute and forward the signal to them.
CHILD_PID=""
PREEMPTED=0

forward_usr1() {
    PREEMPTED=1
    echo "Received SIGUSR1; forwarding to step PID ${CHILD_PID:-none}" >&2
    if [ -n "$CHILD_PID" ]; then
        kill -USR1 "$CHILD_PID" 2>/dev/null || true
    fi
}

# Background children ignore SIGINT in a non-interactive shell, so pass
# interrupts/termination on explicitly before exiting.
stop_step() {
    if [ -n "$CHILD_PID" ]; then
        kill -TERM "$CHILD_PID" 2>/dev/null || true
    fi
    exit "$1"
}

trap forward_usr1 USR1
trap 'stop_step 130' INT
trap 'stop_step 143' TERM

# Run one pipeline step with signal forwarding; returns the step's status.
# Once SIGUSR1 has been received, the script exits 0 after the current step
# instead of starting the next one on a partially trained model.
# NOTE(review): assumes the Python steps install a SIGUSR1 handler (utils.py
# provides SlurmPreemptionHandler); a step without one is terminated by it.
run_step() {
    local rc=0
    if [ "$PREEMPTED" -eq 1 ]; then
        echo "Preemption pending; not starting: $*" >&2
        exit 0
    fi
    "$@" &
    CHILD_PID=$!
    while :; do
        # `wait` returns early (status > 128) when a trapped signal arrives;
        # keep waiting until the step has actually exited.
        if wait "$CHILD_PID"; then
            rc=0
        else
            rc=$?
        fi
        kill -0 "$CHILD_PID" 2>/dev/null || break
    done
    CHILD_PID=""
    if [ "$PREEMPTED" -eq 1 ]; then
        echo "Preempted during step (exit status ${rc}); stopping for requeue" >&2
        exit 0
    fi
    return "$rc"
}

cd "$WORK_DIR"

# --- Setup venv (first run only) ---
if [ ! -d "$VENV_DIR" ]; then
    uv venv --python 3.14 "$VENV_DIR"
fi
# shellcheck disable=SC1091
source "${VENV_DIR}/bin/activate"
uv pip install -e "${NOBRAINER_ROOT}[bayesian,versioning,dev]" \
    monai pyro-ppl datalad matplotlib pyyaml scipy nibabel 2>&1 | tail -3
uv tool install git-annex 2>/dev/null || true

# --- Show GPU info ---
python -c "
from nobrainer.gpu import gpu_info, gpu_count
print('GPUs:', gpu_count())
for g in gpu_info():
    print(' GPU {id}: {name} ({memory_gb} GB)'.format(**g))
"

cd "$SCRIPT_DIR"

# --- Step 1: Data ---
if [ ! -f manifest.csv ]; then
    # DATASETS is intentionally word-split into one argument per dataset ID.
    # shellcheck disable=SC2086
    run_step python 01_assemble_dataset.py --datasets $DATASETS \
        --output-csv manifest.csv --output-dir data --label-mapping binary
fi

# --- Step 2: Deterministic MeshNet ---
run_step python 02_train_meshnet.py --manifest manifest.csv --config config.yaml \
    --output-dir checkpoints/meshnet --epochs "$EPOCHS"

# --- Step 3: All Bayesian variants (auto batch size, checkpoint/resume) ---
for variant in bwn_multi bvwn_multi_prior; do
    echo "=== Training $variant ==="
    run_step python 03_train_bayesian.py --manifest manifest.csv --config config.yaml \
        --variant "$variant" --warmstart checkpoints/meshnet \
        --output-dir "checkpoints/$variant" --epochs "$EPOCHS"
done

# --- Step 4: Evaluate ---
# NOTE(review): --model receives the model.pth file here, whereas
# slurm_kwyk_evaluate.sbatch passes the checkpoint directory; confirm which
# form 04_evaluate.py expects. Evaluation failures are tolerated (|| true).
for v in checkpoints/meshnet checkpoints/bwn_multi checkpoints/bvwn_multi_prior; do
    if [ -f "$v/model.pth" ]; then
        run_step python 04_evaluate.py --model "$v/model.pth" \
            --manifest manifest.csv --split test --n-samples 10 \
            --output-dir "results/$(basename "$v")" || true
    fi
done

# Count trained models with a glob (nullglob -> 0 when none) rather than ls.
shopt -s nullglob
trained_models=(checkpoints/*/model.pth)
echo "=== Done: ${#trained_models[@]} models ==="

================================================
FILE: scripts/kwyk_reproduction/slurm_zarr_array.sbatch
================================================
#!/bin/bash
#SBATCH --job-name=zarr-shard
#SBATCH --partition=pi_satra
#SBATCH --cpus-per-task=4
#SBATCH --mem=32G
#SBATCH --time=01:00:00
#SBATCH --output=slurm-zarr-shard-%A_%a.out
#SBATCH --error=slurm-zarr-shard-%A_%a.err
#SBATCH --array=0-114
#
# Job array: each task writes one shard (100 subjects) to the Zarr store.
# 11479 subjects / 100 per shard = 115 shards (0-114)
# Task 0 also creates the store.
#
set -euo pipefail

WORK_DIR="/orcd/scratch/orcd/013/satra/kwyk_reproduction"
VENV_DIR="/orcd/data/satra/002/projects/nobrainer/venvs/nobrainer"
MANIFEST="$WORK_DIR/kwyk_manifest_full.csv"
ZARR_OUT="$WORK_DIR/data/kwyk_full.zarr"
SUBJECTS_PER_SHARD=100
# Upper bound (seconds) on how long tasks 1..N wait for task 0 to create the store.
STORE_WAIT_TIMEOUT="${STORE_WAIT_TIMEOUT:-1800}"
TASK_ID="${SLURM_ARRAY_TASK_ID:?this script must run as a SLURM array task}"

cd "$WORK_DIR"
source "${VENV_DIR}/bin/activate"

if [ "$TASK_ID" -eq 0 ]; then
    echo "=== Shard 0: creating store + writing first shard ==="
    python convert_zarr_shard.py \
        --manifest "$MANIFEST" \
        --zarr-store "$ZARR_OUT" \
        --shard-idx 0 \
        --subjects-per-shard "$SUBJECTS_PER_SHARD" \
        --create
else
    # Array tasks are scheduled independently, so task 0 may start well after
    # this one. Poll for the store's root metadata instead of relying on a
    # fixed sleep. The file names assume a Zarr v3 (zarr.json) or v2 (.zgroup)
    # layout — NOTE(review): confirm against convert_zarr_shard.py --create.
    waited=0
    until [ -e "$ZARR_OUT/zarr.json" ] || [ -e "$ZARR_OUT/.zgroup" ]; do
        if [ "$waited" -ge "$STORE_WAIT_TIMEOUT" ]; then
            echo "ERROR: $ZARR_OUT not created after ${STORE_WAIT_TIMEOUT}s" >&2
            exit 1
        fi
        sleep 10
        waited=$((waited + 10))
    done
    # Short grace period so task 0 can finish creating the arrays inside the
    # store (the same delay the previous fixed sleep used).
    sleep 10

    echo "=== Shard ${TASK_ID}: writing ==="
    python convert_zarr_shard.py \
        --manifest "$MANIFEST" \
        --zarr-store "$ZARR_OUT" \
        --shard-idx "$TASK_ID" \
        --subjects-per-shard "$SUBJECTS_PER_SHARD"
fi

echo "=== Done: $(date) ==="

================================================
FILE: scripts/kwyk_reproduction/submit_kwyk_smoke.sh
================================================
#!/bin/bash
# Submit the full KWYK PAC smoke test pipeline as parallel SLURM jobs:
#
#   Job 1:    MeshNet (deterministic warm-start)
#   Jobs 2-4: 3 Bayesian variants (parallel, depend on Job 1)
#   Job 5:    Evaluate all variants (depends on Jobs 2-4)
#
set -euo pipefail
cd "$(dirname "${BASH_SOURCE[0]}")"

echo "=== Submitting KWYK PAC smoke test pipeline ==="

# Step 1: MeshNet
MESHNET_JOB=$(sbatch --parsable slurm_kwyk_smoke.sbatch)
echo "MeshNet: job ${MESHNET_JOB}"

# Step 2: Bayesian variants (parallel, depend on MeshNet)
# IDs are joined with ':' so they all belong to one "afterok:ID:ID:..." clause.
# A ',' starts a new type:job_id clause, so the IDs after it would not be
# covered by "afterok" as intended.
BAYES_JOBS=""
for variant in bwn_multi bvwn_multi_prior bayesian_gaussian; do
    JOB=$(sbatch --parsable --dependency="afterok:${MESHNET_JOB}" slurm_kwyk_bayesian.sbatch "$variant")
    echo "${variant}: job ${JOB} (after ${MESHNET_JOB})"
    BAYES_JOBS="${BAYES_JOBS:+${BAYES_JOBS}:}${JOB}"
done

# Step 3: Evaluate (depends on all Bayesian jobs + MeshNet)
EVAL_JOB=$(sbatch --parsable --dependency="afterok:${MESHNET_JOB}:${BAYES_JOBS}" slurm_kwyk_evaluate.sbatch)
echo "Evaluate: job ${EVAL_JOB} (after all training)"
echo ""
echo "=== Pipeline submitted ==="
echo "Monitor: squeue -u \$USER"

================================================
FILE: scripts/kwyk_reproduction/utils.py
================================================
"""Shared utilities for kwyk reproduction experiments."""

from __future__ import annotations

import json
import logging
from pathlib import Path
import signal
from typing import Any

import numpy as np
import torch


def load_config(path: str | Path) -> dict[str, Any]:
    """Load a YAML configuration file and return its contents as a dict.

    Parameters
    ----------
    path : str or Path
        Path to the YAML file.

    Returns
    -------
    dict
        Parsed configuration.
    """
    import yaml

    path = Path(path)
    with open(path) as f:
        return yaml.safe_load(f)


def setup_logging(name: str) -> logging.Logger:
    """Configure and return a logger with timestamped format.

    Parameters
    ----------
    name : str
        Logger name (typically ``__name__``).

    Returns
    -------
    logging.Logger
        Configured logger instance.
    """
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger


def save_figure(fig: Any, path: str | Path) -> None:
    """Save a matplotlib figure, creating parent directories if needed.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure to save.
    path : str or Path
        Destination file path.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, bbox_inches="tight", dpi=150)


def compute_dice(pred: np.ndarray, label: np.ndarray) -> float:
    """Compute the Dice score between two binary volumes.

    Parameters
    ----------
    pred : np.ndarray
        Binary prediction array.
    label : np.ndarray
        Binary ground-truth array.

    Returns
    -------
    float
        Dice coefficient in [0, 1].
Returns 1.0 when both arrays are empty. """ pred = pred.astype(bool) label = label.astype(bool) intersection = np.logical_and(pred, label).sum() total = pred.sum() + label.sum() if total == 0: return 1.0 return float(2.0 * intersection / total) def apply_label_mapping( label_vol: np.ndarray, mapping_csv: str | Path | None = None ) -> np.ndarray: """Remap FreeSurfer label codes in a volume. When *mapping_csv* is ``None`` the volume is binarised (``(vol > 0).astype(int)``). Otherwise a CSV with columns ``original,new`` is loaded and used to build a lookup table that maps each original code to its new value. Parameters ---------- label_vol : np.ndarray Integer label volume. mapping_csv : str, Path, or None Path to a CSV mapping file. If ``None``, perform binary thresholding. Returns ------- np.ndarray Remapped label volume with the same shape as the input. """ if mapping_csv is None: return (label_vol > 0).astype(int) import csv mapping_csv = Path(mapping_csv) lookup: dict[int, int] = {} with open(mapping_csv) as f: reader = csv.DictReader(f) for row in reader: lookup[int(row["original"])] = int(row["new"]) mapper = np.vectorize(lambda v: lookup.get(v, 0)) return mapper(label_vol) # --------------------------------------------------------------------------- # Checkpoint / resume for SLURM preemptible jobs # --------------------------------------------------------------------------- _logger = logging.getLogger(__name__) class SlurmPreemptionHandler: """Handle SLURM preemption signals for graceful checkpoint-and-exit. SLURM sends SIGUSR1 (or the signal specified by ``--signal``) before killing a preempted job. This handler sets a flag so the training loop can checkpoint and exit cleanly. The ``--requeue`` sbatch flag then re-submits the job, and the training resumes from the checkpoint. Usage:: handler = SlurmPreemptionHandler() for epoch in range(start_epoch, total_epochs): train_one_epoch(...) save_checkpoint(...) 
if handler.preempted: log.info("Preempted — exiting for requeue") sys.exit(0) """ def __init__(self, sig: int = signal.SIGUSR1) -> None: self.preempted = False self._sig = sig signal.signal(sig, self._handle) _logger.info("SLURM preemption handler registered (signal=%s)", sig.name) def _handle(self, signum: int, frame: Any) -> None: _logger.warning( "Received preemption signal %d — will checkpoint and exit", signum ) self.preempted = True def save_training_checkpoint( checkpoint_dir: Path, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int, metrics: dict[str, Any], ) -> Path: """Save a resumable training checkpoint. Writes ``checkpoint.pt`` containing model weights, optimizer state, epoch number, and accumulated metrics (losses, Dice scores, etc.). Also writes ``checkpoint_meta.json`` with human-readable status. Parameters ---------- checkpoint_dir : Path Directory to save checkpoint files. model : torch.nn.Module Model to checkpoint. optimizer : torch.optim.Optimizer Optimizer to checkpoint (includes momentum, lr schedule state). epoch : int Completed epoch number (0-indexed). metrics : dict Accumulated training metrics to persist across restarts. Returns ------- Path Path to the written checkpoint file. 
""" checkpoint_dir.mkdir(parents=True, exist_ok=True) ckpt_path = checkpoint_dir / "checkpoint.pt" torch.save( { "epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "metrics": metrics, }, ckpt_path, ) # Human-readable metadata meta = { "epoch": epoch, "best_loss": metrics.get("best_loss", None), "train_losses": metrics.get("train_losses", [])[-3:], } with open(checkpoint_dir / "checkpoint_meta.json", "w") as f: json.dump(meta, f, indent=2, default=str) _logger.info("Checkpoint saved: epoch %d → %s", epoch, ckpt_path) return ckpt_path def load_training_checkpoint( checkpoint_dir: Path, model: torch.nn.Module, optimizer: torch.optim.Optimizer | None = None, ) -> tuple[int, dict[str, Any]]: """Load a training checkpoint and return (start_epoch, metrics). Parameters ---------- checkpoint_dir : Path Directory containing ``checkpoint.pt``. model : torch.nn.Module Model to load weights into. optimizer : torch.optim.Optimizer or None Optimizer to restore state into. If None, only model is loaded. Returns ------- start_epoch : int The next epoch to train (checkpoint epoch + 1). metrics : dict Accumulated metrics from previous training. 
""" ckpt_path = checkpoint_dir / "checkpoint.pt" if not ckpt_path.exists(): _logger.info("No checkpoint found at %s — starting from scratch", ckpt_path) return 0, {} ckpt = torch.load(ckpt_path, weights_only=False) model.load_state_dict(ckpt["model_state_dict"]) if optimizer is not None and "optimizer_state_dict" in ckpt: optimizer.load_state_dict(ckpt["optimizer_state_dict"]) start_epoch = ckpt["epoch"] + 1 metrics = ckpt.get("metrics", {}) _logger.info( "Resumed from checkpoint: epoch %d, best_loss=%.6f", ckpt["epoch"], metrics.get("best_loss", float("inf")), ) return start_epoch, metrics ================================================ FILE: scripts/synthseg_evaluation/02_train.py ================================================ #!/usr/bin/env python """Train a segmentation model with real, synthetic, or mixed data. Usage: python 02_train.py --config config.yaml --mode real --model unet python 02_train.py --config config.yaml --mode mixed --model swin_unetr python 02_train.py --config config.yaml --mode synthetic --model kwyk_meshnet """ from __future__ import annotations import argparse import csv import logging from pathlib import Path import yaml logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") log = logging.getLogger(__name__) def parse_args(): parser = argparse.ArgumentParser(description="Train with SynthSeg evaluation") parser.add_argument("--config", default="config.yaml") parser.add_argument("--mode", choices=["real", "synthetic", "mixed"], required=True) parser.add_argument("--model", required=True) parser.add_argument("--manifest", default="manifest.csv") parser.add_argument("--output-dir", default="checkpoints") parser.add_argument("--epochs", type=int, default=None) return parser.parse_args() def load_manifest(path, split): pairs = [] with open(path) as f: for row in csv.DictReader(f): if row["split"] == split: pairs.append((row["t1w_path"], row["label_path"])) return pairs def main(): args = parse_args() with 
open(args.config) as f: config = yaml.safe_load(f) data_cfg = config["data"] synth_cfg = config["synthseg"] train_cfg = config["training"] epochs = args.epochs or train_cfg["epochs"] n_classes = data_cfg["n_classes"] block_shape = tuple(data_cfg["block_shape"]) batch_size = data_cfg["batch_size"] lr = train_cfg["lr"] label_mapping = data_cfg["label_mapping"] output_dir = Path(args.output_dir) / f"{args.model}_{args.mode}" output_dir.mkdir(parents=True, exist_ok=True) # Load manifest train_pairs = load_manifest(args.manifest, "train") log.info( "Training: mode=%s, model=%s, %d volumes, %d epochs", args.mode, args.model, len(train_pairs), epochs, ) from nobrainer.processing.dataset import Dataset from nobrainer.processing.segmentation import Segmentation # Build dataset based on mode ds = ( Dataset.from_files(train_pairs, block_shape=block_shape, n_classes=n_classes) .batch(batch_size) .binarize(label_mapping) .augment(train_cfg.get("augmentation_profile", "standard")) ) if args.mode == "synthetic" or args.mode == "mixed": from nobrainer.augmentation.synthseg import SynthSegGenerator label_paths = [p[1] for p in train_pairs] gen = SynthSegGenerator( label_paths, n_samples_per_map=synth_cfg["n_samples_per_map"], elastic_std=synth_cfg["elastic_std"], rotation_range=synth_cfg["rotation_range"], scaling_bounds=synth_cfg["scaling_bounds"], flipping=synth_cfg["flipping"], randomize_resolution=synth_cfg["randomize_resolution"], resolution_range=tuple(synth_cfg["resolution_range"]), bias_field_std=synth_cfg["bias_field_std"], noise_std=synth_cfg["noise_std"], intensity_prior=tuple(synth_cfg["intensity_prior"]), std_prior=tuple(synth_cfg["std_prior"]), ) if args.mode == "mixed": ds = ds.mix(gen, ratio=train_cfg["mixed_ratio"]) log.info("Mixed mode: %.0f%% synthetic", train_cfg["mixed_ratio"] * 100) # Build model model_args = {"n_classes": n_classes} if args.model in ("swin_unetr", "segresnet"): model_args["feature_size"] = 12 if args.model == "swin_unetr" else 16 seg = 
Segmentation( args.model, model_args=model_args, checkpoint_filepath=str(output_dir) ) # Experiment tracking from nobrainer.experiment import ExperimentTracker tracker = ExperimentTracker( output_dir=output_dir, config={ "mode": args.mode, "model": args.model, "epochs": epochs, "n_classes": n_classes, "batch_size": batch_size, }, project="synthseg-evaluation", name=f"{args.model}_{args.mode}", ) import torch seg.fit( ds, epochs=epochs, optimizer=torch.optim.Adam, opt_args={"lr": lr}, callbacks=[tracker.callback(mode=args.mode, model=args.model)], ) seg.save(output_dir) tracker.finish() log.info("Model saved to %s", output_dir) if __name__ == "__main__": main() ================================================ FILE: scripts/synthseg_evaluation/03_evaluate.py ================================================ #!/usr/bin/env python """Evaluate a trained model with per-class Dice scoring. Reuses the evaluation logic from the kwyk reproduction pipeline. Usage: python 03_evaluate.py --model checkpoints/unet_real --manifest manifest.csv """ from __future__ import annotations import argparse import csv import logging from pathlib import Path import nibabel as nib import numpy as np import yaml logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") log = logging.getLogger(__name__) def per_class_dice(pred, gt, n_classes): """Compute Dice per class c=1..n_classes-1 (skip background).""" dice = np.zeros(n_classes - 1) for c in range(1, n_classes): p = (pred == c).astype(np.float64) g = (gt == c).astype(np.float64) intersection = (p * g).sum() total = p.sum() + g.sum() dice[c - 1] = 2.0 * intersection / total if total > 0 else 1.0 return dice def main(): parser = argparse.ArgumentParser() parser.add_argument("--model", required=True) parser.add_argument("--manifest", required=True) parser.add_argument("--config", default="config.yaml") parser.add_argument("--split", default="test") parser.add_argument("--output-dir", default=None) args = 
parser.parse_args() with open(args.config) as f: config = yaml.safe_load(f) n_classes = config["data"]["n_classes"] block_shape = tuple(config["data"]["block_shape"]) label_mapping = config["data"]["label_mapping"] output_dir = Path(args.output_dir or args.model) / "eval" output_dir.mkdir(parents=True, exist_ok=True) # Load model from nobrainer.processing.segmentation import Segmentation seg = Segmentation.load(args.model) # Load remap function remap_fn = None if label_mapping and label_mapping != "binary": from nobrainer.processing.dataset import _load_label_mapping remap_fn = _load_label_mapping(label_mapping) # Load test pairs pairs = [] with open(args.manifest) as f: for row in csv.DictReader(f): if row["split"] == args.split: pairs.append((row["t1w_path"], row["label_path"])) log.info("Evaluating %d volumes", len(pairs)) results = [] all_dice = [] for i, (img_path, lbl_path) in enumerate(pairs): gt = np.asarray(nib.load(lbl_path).dataobj, dtype=np.int32) if remap_fn is not None: gt = remap_fn(gt) pred_img = seg.predict(img_path, block_shape=block_shape) pred = np.asarray(pred_img.dataobj, dtype=np.int32) dice = per_class_dice(pred, gt, n_classes) avg = float(dice.mean()) all_dice.append(dice) results.append({"volume": Path(img_path).stem, "avg_dice": avg}) log.info(" %d/%d: %s — Dice=%.4f", i + 1, len(pairs), Path(img_path).stem, avg) # Save results csv_path = output_dir / "dice_scores.csv" with open(csv_path, "w", newline="") as f: w = csv.DictWriter(f, ["volume", "avg_dice"]) w.writeheader() w.writerows(results) np.save(output_dir / "per_class_dice.npy", np.array(all_dice)) avg_dices = [r["avg_dice"] for r in results] log.info("Class Dice: %.4f ± %.4f", np.mean(avg_dices), np.std(avg_dices)) if __name__ == "__main__": main() ================================================ FILE: scripts/synthseg_evaluation/04_compare.py ================================================ #!/usr/bin/env python """Compare results across training modes and models. 
Usage: python 04_compare.py --results-dir checkpoints/ """ from __future__ import annotations import argparse import csv import logging from pathlib import Path import numpy as np logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") log = logging.getLogger(__name__) def main(): parser = argparse.ArgumentParser() parser.add_argument("--results-dir", default="checkpoints") parser.add_argument("--output-dir", default="results") args = parser.parse_args() results_dir = Path(args.results_dir) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) # Scan for eval results: checkpoints/_/eval/dice_scores.csv rows = [] for eval_dir in sorted(results_dir.glob("*/eval")): csv_path = eval_dir / "dice_scores.csv" if not csv_path.exists(): continue name = eval_dir.parent.name # e.g., "unet_real" parts = name.rsplit("_", 1) model = parts[0] if len(parts) == 2 else name mode = parts[1] if len(parts) == 2 else "unknown" with open(csv_path) as f: scores = [float(r["avg_dice"]) for r in csv.DictReader(f)] if scores: rows.append( { "model": model, "mode": mode, "mean_dice": f"{np.mean(scores):.4f}", "std_dice": f"{np.std(scores):.4f}", "n_volumes": len(scores), } ) log.info( "%s (%s): %.4f ± %.4f", model, mode, np.mean(scores), np.std(scores) ) if not rows: log.warning("No results found in %s", results_dir) return # Write comparison table csv_path = output_dir / "comparison_table.csv" with open(csv_path, "w", newline="") as f: w = csv.DictWriter(f, ["model", "mode", "mean_dice", "std_dice", "n_volumes"]) w.writeheader() w.writerows(rows) log.info("Comparison table: %s", csv_path) # Generate figure try: import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt models = sorted(set(r["model"] for r in rows)) modes = sorted(set(r["mode"] for r in rows)) x = np.arange(len(models)) width = 0.25 fig, ax = plt.subplots(figsize=(max(8, len(models) * 2), 6)) for i, mode in enumerate(modes): means = [] stds = [] for model 
in models: match = [r for r in rows if r["model"] == model and r["mode"] == mode] if match: means.append(float(match[0]["mean_dice"])) stds.append(float(match[0]["std_dice"])) else: means.append(0) stds.append(0) ax.bar(x + i * width, means, width, yerr=stds, label=mode, capsize=3) ax.set_xlabel("Model") ax.set_ylabel("Mean Class Dice") ax.set_title("SynthSeg Evaluation: Model × Training Mode") ax.set_xticks(x + width * (len(modes) - 1) / 2) ax.set_xticklabels(models) ax.set_ylim(0, 1.05) ax.legend() fig.tight_layout() fig.savefig(output_dir / "comparison_figure.png", dpi=150) plt.close(fig) log.info("Comparison figure: %s", output_dir / "comparison_figure.png") except ImportError: log.warning("matplotlib not available, skipping figure") if __name__ == "__main__": main() ================================================ FILE: scripts/synthseg_evaluation/README.md ================================================ # SynthSeg Evaluation Pipeline Evaluate SynthSeg-based training against real-data baselines using multiple model architectures. ## Quick Start ```bash cd scripts/synthseg_evaluation # Smoke test (2 epochs, unet, real+mixed) ./run.sh --smoke-test # Full evaluation (all models × all modes from config.yaml) ./run.sh ``` ## Training Modes | Mode | Description | |------|-------------| | `real` | Train on real data only (baseline) | | `synthetic` | Train on SynthSeg-generated data only | | `mixed` | Train on mix of real + synthetic (configurable ratio) | ## Available Models | Model | Architecture | Source | |-------|-------------|--------| | `unet` | 3D U-Net | MONAI | | `swin_unetr` | Swin Transformer U-Net | MONAI | | `segresnet` | Residual Encoder SegNet | MONAI | | `kwyk_meshnet` | VWN MeshNet + dropout | nobrainer | | `attention_unet` | Attention U-Net | MONAI | ## Configuration Edit `config.yaml` to change models, training modes, SynthSeg parameters, and data settings. 
Key options:
- `training.modes`: which modes to evaluate
- `training.mixed_ratio`: fraction of synthetic data in mixed mode
- `models`: list of model architectures to compare
- `synthseg.*`: SynthSeg generation parameters

## SLURM

```bash
# Single model+mode
SYNTHSEG_MODE=mixed SYNTHSEG_MODEL=swin_unetr sbatch slurm_train.sbatch

# All combinations
for model in unet swin_unetr kwyk_meshnet; do
  for mode in real synthetic mixed; do
    SYNTHSEG_MODE=$mode SYNTHSEG_MODEL=$model sbatch slurm_train.sbatch
  done
done
```

## Output

```
results/
├── comparison_table.csv     # Dice per model × mode
└── comparison_figure.png    # Bar chart visualization

checkpoints/
├── unet_real/eval/          # Per-model eval results
├── unet_mixed/eval/
├── swin_unetr_real/eval/
└── ...
```

================================================
FILE: scripts/synthseg_evaluation/config.yaml
================================================
# SynthSeg Evaluation Pipeline Configuration
#
# Read by 02_train.py (data.n_classes/block_shape/batch_size/label_mapping,
# synthseg.*, training.epochs/lr/mixed_ratio/augmentation_profile),
# 03_evaluate.py (data.n_classes/block_shape/label_mapping) and run.sh
# (training.modes, models).

data:
  # NOTE(review): datasets and split do not appear to be read by the
  # pipeline scripts here — confirm before relying on them.
  datasets: [ds000114]
  n_classes: 50
  label_mapping: 50-class
  block_shape: [32, 32, 32]
  batch_size: 32
  split: [80, 10, 10]

synthseg:
  n_samples_per_map: 20
  elastic_std: 4.0
  rotation_range: 15.0
  scaling_bounds: 0.2
  flipping: true
  randomize_resolution: true
  resolution_range: [1.0, 3.0]
  bias_field_std: 0.7
  noise_std: 0.1
  intensity_prior: [0, 250]
  std_prior: [0, 35]

training:
  modes: [real, synthetic, mixed]
  mixed_ratio: 0.3          # fraction of synthetic samples in "mixed" mode
  epochs: 50
  lr: 0.0001
  augmentation_profile: standard

models:
  - unet
  - swin_unetr
  - segresnet
  - segformer3d
  - kwyk_meshnet

# NOTE(review): 03_evaluate.py does not appear to read this section.
evaluation:
  n_samples: 10
  metrics: [per_class_dice, mean_dice]

# NOTE(review): run.sh --smoke-test hard-codes 2 epochs, unet and real+mixed
# and does not appear to read this section; batch_size/block_shape still come
# from "data" above.
smoke_test:
  epochs: 2
  n_samples_per_map: 2
  batch_size: 2
  block_shape: [16, 16, 16]
  models: [unet]
  modes: [real, mixed]

================================================
FILE: scripts/synthseg_evaluation/run.sh
================================================
#!/bin/bash
# SynthSeg Evaluation Pipeline Orchestrator
#
# Usage:
#   ./run.sh --smoke-test                 # Quick test (2 epochs, 1 model)
#   ./run.sh                              # Full evaluation
#   ./run.sh --config custom.yaml
# Custom config set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" CONFIG="${1:-config.yaml}" SMOKE=false while [[ $# -gt 0 ]]; do case "$1" in --smoke-test) SMOKE=true; shift ;; --config) CONFIG="$2"; shift 2 ;; *) shift ;; esac done cd "$SCRIPT_DIR" echo "=== SynthSeg Evaluation Pipeline ===" echo "Config: $CONFIG" echo "Smoke test: $SMOKE" # Use sample data if no manifest exists if [ ! -f manifest.csv ]; then echo "=== Creating manifest from sample data ===" python -c " import csv from nobrainer.utils import get_data src = get_data() pairs = [] with open(src) as f: r = csv.reader(f); next(r) pairs = list(r)[:5] splits = ['train','train','train','val','test'] with open('manifest.csv', 'w', newline='') as f: w = csv.DictWriter(f, ['t1w_path','label_path','split']); w.writeheader() for i,(t1,lbl) in enumerate(pairs): w.writerow(dict(t1w_path=t1, label_path=lbl, split=splits[i])) print('Manifest created with', len(pairs), 'volumes') " fi if [ "$SMOKE" = true ]; then echo "=== Smoke test: 2 epochs, unet, real+mixed ===" for mode in real mixed; do echo " Training unet ($mode)..." python 02_train.py --config "$CONFIG" --mode "$mode" --model unet \ --epochs 2 --manifest manifest.csv done for mode in real mixed; do echo " Evaluating unet ($mode)..." 
python 03_evaluate.py --model "checkpoints/unet_${mode}" \ --manifest manifest.csv --config "$CONFIG" || true done python 04_compare.py --results-dir checkpoints/ --output-dir results/ || true else # Full evaluation from config MODELS=$(python -c "import yaml; c=yaml.safe_load(open('$CONFIG')); print(' '.join(c['training']['modes']))") MODES=$(python -c "import yaml; c=yaml.safe_load(open('$CONFIG')); print(' '.join(c['models']))") for model in $MODES; do for mode in $MODELS; do echo "=== Training $model ($mode) ===" python 02_train.py --config "$CONFIG" --mode "$mode" --model "$model" \ --manifest manifest.csv done done for model in $MODES; do for mode in $MODELS; do echo "=== Evaluating $model ($mode) ===" python 03_evaluate.py --model "checkpoints/${model}_${mode}" \ --manifest manifest.csv --config "$CONFIG" || true done done python 04_compare.py --results-dir checkpoints/ --output-dir results/ fi echo "=== Done. Results in results/ ===" ================================================ FILE: scripts/synthseg_evaluation/slurm_train.sbatch ================================================ #!/bin/bash #SBATCH --job-name=synthseg-eval #SBATCH --partition=preemptible #SBATCH --gres=gpu:1 #SBATCH --cpus-per-task=4 #SBATCH --mem=32G #SBATCH --time=24:00:00 #SBATCH --requeue #SBATCH --signal=B:USR1@120 #SBATCH --output=slurm-%j.out #SBATCH --error=slurm-%j.err # # SynthSeg Evaluation — SLURM Preemptible Training # # Usage: # sbatch slurm_train.sbatch # SYNTHSEG_MODE=mixed SYNTHSEG_MODEL=swin_unetr sbatch slurm_train.sbatch set -euo pipefail MODE="${SYNTHSEG_MODE:-real}" MODEL="${SYNTHSEG_MODEL:-unet}" CONFIG="${SYNTHSEG_CONFIG:-config.yaml}" SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" echo "=== SynthSeg SLURM Training ===" echo "Model: $MODEL, Mode: $MODE" echo "Job: ${SLURM_JOB_ID:-local}, Restart: ${SLURM_RESTART_COUNT:-0}" cd "$SCRIPT_DIR" python 02_train.py --config "$CONFIG" --mode "$MODE" --model "$MODEL" \ --manifest manifest.csv python 
03_evaluate.py --model "checkpoints/${MODEL}_${MODE}" \ --manifest manifest.csv --config "$CONFIG"