Repository: HansBambel/SmaAt-UNet
Branch: master
Commit: 3e48dcca3e88
Files: 32
Total size: 108.5 KB

Directory structure:
gitextract_3pkgi8mx/

├── .gitignore
├── .pre-commit-config.yaml
├── .python-version
├── README.md
├── calc_metrics_test_set.py
├── create_datasets.py
├── metric/
│   ├── __init__.py
│   ├── confusionmatrix.py
│   ├── iou.py
│   ├── metric.py
│   └── precipitation_metrics.py
├── models/
│   ├── SmaAt_UNet.py
│   ├── __init__.py
│   ├── layers.py
│   ├── regression_lightning.py
│   ├── unet_parts.py
│   ├── unet_parts_depthwise_separable.py
│   └── unet_precip_regression_lightning.py
├── plot_examples.ipynb
├── pyproject.toml
├── requirements-dev.txt
├── requirements-full.txt
├── requirements.txt
├── root.py
├── train_SmaAtUNet.py
├── train_precip_lightning.py
└── utils/
    ├── __init__.py
    ├── data_loader_precip.py
    ├── dataset_VOC.py
    ├── dataset_precip.py
    ├── formatting.py
    └── model_classes.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.idea
.vscode
data
runs
lightning
checkpoints
plots
.ruff_cache
.mypy_cache
__pycache__


================================================
FILE: .pre-commit-config.yaml
================================================
repos:
-   repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
    -   id: check-yaml
    -   id: end-of-file-fixer
    -   id: trailing-whitespace
- repo: https://github.com/charliermarsh/ruff-pre-commit
  rev: 'v0.9.7'
  hooks:
    - id: ruff
      args: ["--fix"]
    - id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
  rev: v1.15.0
  hooks:
  - id: mypy

================================================
FILE: .python-version
================================================
3.10.11


================================================
FILE: README.md
================================================
# SmaAt-UNet
Code for the Paper "SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture" [Arxiv-link](https://arxiv.org/abs/2007.04417), [Elsevier-link](https://www.sciencedirect.com/science/article/pii/S0167865521000556?via%3Dihub)

Other relevant publications: [Weather Elements Nowcasting](https://sites.google.com/view/siamak-mehrkanoon/projects/weather-elements-nowcasting). 


![SmaAt-UNet](SmaAt-UNet.png)

The proposed SmaAt-UNet can be found in the model-folder under [SmaAt_UNet](models/SmaAt_UNet.py).

**>>>IMPORTANT<<<**

The original code from the paper can be found in this branch: https://github.com/HansBambel/SmaAt-UNet/tree/snapshot-paper

The current master branch has since been refactored and its packages upgraded. Because the exact package versions differ, the experiments may not be 100% reproducible.

If you have problems running the code, feel free to open an issue here on Github.

## Installing dependencies
This project uses [uv](https://docs.astral.sh/uv/) for dependency management. Installing the required dependencies is therefore as easy as:
```shell
uv sync --frozen
```

If you have Nvidia GPU(s) that support CUDA and the command above did not install torch with CUDA support, you can install the relevant dependencies with the following command:
```shell
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
```

In any case, a [`requirements.txt`](requirements.txt) file is generated from `uv.lock`, which means `uv` must be installed in order to perform the export:
```shell
uv export --format=requirements-txt --no-dev --no-hashes --no-sources --output-file=requirements.txt
```

A `requirements-dev.txt` file is also generated for development dependencies:
```shell
uv export --format=requirements-txt --group dev --no-hashes --no-sources --output-file=requirements-dev.txt
```

A `requirements-full.txt` file is also generated for all dependencies (including development and additional dependencies):
```shell
uv export --format=requirements-txt --all-groups --no-hashes --no-sources --output-file=requirements-full.txt
```

Alternatively, users can use the existing [`requirements.txt`](requirements.txt), [`requirements-dev.txt`](requirements-dev.txt) and [`requirements-full.txt`](requirements-full.txt) files.

Only the following requirements are needed for training:
```
tqdm
torch
lightning
matplotlib
tensorboard
torchsummary
h5py
numpy
```

Additional requirements are needed for [testing](#testing):
```
ipykernel
```

---
For the paper we used [Lightning](https://github.com/Lightning-AI/lightning) (PL), which simplifies the training process and makes it easy to add loggers and checkpointing.
In order to use PL we created the model [UNetDSAttention](models/unet_precip_regression_lightning.py), whose parent class inherits from `pl.LightningModule`. This model is the same as the pure PyTorch SmaAt-UNet implementation with the added PL functions.

### Training
An example [training script](train_SmaAtUNet.py) is given for a classification task (PascalVOC).

For training on the precipitation task we used the [train_precip_lightning.py](train_precip_lightning.py) file.
Training places a checkpoint file for every model in the `default_save_path` `lightning/precip_regression`.
After training has finished, place the best models (typically those with the lowest validation loss) that you want to compare into the folder `checkpoints/comparison`.

### Testing
The script [calc_metrics_test_set.py](calc_metrics_test_set.py) will use all the models placed in `checkpoints/comparison` to calculate the MSEs and other metrics such as Precision, Recall, Accuracy, F1, CSI, FAR, HSS.
The results are saved as a JSON file in the same folder as the models.
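The classification scores listed above are computed from the binary confusion counts (TP, FP, TN, FN) after thresholding the precipitation maps. As a reference, here is a sketch of the standard forecast-verification definitions of CSI, FAR, and HSS; these are the usual textbook formulas, not code copied from this repository:

```python
# Hedged sketch: standard forecast-verification scores from binary
# confusion counts. Formulas follow common meteorological definitions.

def csi(tp: int, fp: int, fn: int) -> float:
    """Critical Success Index (threat score): hits over all events flagged or missed."""
    return tp / (tp + fp + fn)


def far(tp: int, fp: int) -> float:
    """False Alarm Ratio: fraction of predicted rain events that did not occur."""
    return fp / (tp + fp)


def hss(tp: int, fp: int, tn: int, fn: int) -> float:
    """Heidke Skill Score: accuracy relative to random chance."""
    num = 2 * (tp * tn - fp * fn)
    den = (tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)
    return num / den
```

CSI and FAR range from 0 to 1 (higher CSI and lower FAR are better), while HSS is at most 1, with 0 meaning no skill over chance.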

The metrics of the persistence model are (for now) only calculated by the script [test_precip_lightning.py](test_precip_lightning.py). This script also uses all models in that folder and calculates their test losses in addition to the persistence model.
This will be addressed by [this issue](https://github.com/HansBambel/SmaAt-UNet/issues/28).

### Plots
Example code for creating similar plots as in the paper can be found in [plot_examples.ipynb](plot_examples.ipynb).

You might need to make your kernel discoverable so that Jupyter can detect it. To do so, run the following command in your terminal:

```shell
uv run ipython kernel install --user --env VIRTUAL_ENV=$(pwd)/.venv --name=smaat_unet --display-name="SmaAt-UNet Kernel"
```

### Precipitation dataset
The dataset consists of precipitation maps in 5-minute intervals from 2016 to 2019, resulting in about 420,000 images.

The dataset is based on radar precipitation maps from the [The Royal Netherlands Meteorological Institute (KNMI)](https://www.knmi.nl/over-het-knmi/about).
The original images were cropped as can be seen in the example below:
![Precip cutout](Precipitation%20map%20Cutout.png)

If you are interested in the dataset that we used please write an e-mail to: k.trebing@alumni.maastrichtuniversity.nl and s.mehrkanoon@uu.nl

The 50% dataset is 4 GB in size and the 20% dataset is 16.5 GB. Use [create_datasets.py](create_datasets.py) to create the two datasets from the original dataset.

The dataset is already normalized using a [Min-Max normalization](https://en.wikipedia.org/wiki/Feature_scaling#Rescaling_(min-max_normalization)).
To revert this, multiply the images by 47.83; the resulting values are in mm/5min.
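The denormalization described above is a single multiplication. A minimal sketch (the helper names are illustrative, not part of the repository):

```python
# Convert normalized radar values back to physical precipitation.
# The factor 47.83 is the Min-Max normalization constant stated in the README.
FACTOR = 47.83


def denormalize(value_normalized: float) -> float:
    """Undo the Min-Max scaling; the result is precipitation in mm per 5 minutes."""
    return value_normalized * FACTOR


def mm_per_5min_to_mm_per_hour(value: float) -> float:
    """Twelve 5-minute intervals make up an hour."""
    return value * 12.0
```

The same multiplication applies elementwise to whole image arrays, e.g. `images * 47.83` on a NumPy array or torch tensor.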

### Citation
```
@article{TREBING2021,
title = {SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture},
journal = {Pattern Recognition Letters},
year = {2021},
issn = {0167-8655},
doi = {https://doi.org/10.1016/j.patrec.2021.01.036},
url = {https://www.sciencedirect.com/science/article/pii/S0167865521000556},
author = {Kevin Trebing and Tomasz Staǹczyk and Siamak Mehrkanoon},
keywords = {Domain adaptation, neural networks, kernel methods, coupling regularization},
abstract = {Weather forecasting is dominated by numerical weather prediction that tries to model accurately the physical properties of the atmosphere. A downside of numerical weather prediction is that it is lacking the ability for short-term forecasts using the latest available information. By using a data-driven neural network approach we show that it is possible to produce an accurate precipitation nowcast. To this end, we propose SmaAt-UNet, an efficient convolutional neural networks-based on the well known UNet architecture equipped with attention modules and depthwise-separable convolutions. We evaluate our approaches on a real-life datasets using precipitation maps from the region of the Netherlands and binary images of cloud coverage of France. The experimental results show that in terms of prediction performance, the proposed model is comparable to other examined models while only using a quarter of the trainable parameters.}
}

```

### Other relevant publications

Your project might also benefit from exploring our other relevant publications: [Weather Elements Nowcasting](https://sites.google.com/view/siamak-mehrkanoon/projects/weather-elements-nowcasting). Please consider citing them if you find them useful.



================================================
FILE: calc_metrics_test_set.py
================================================
import argparse
from argparse import Namespace
import json
import csv
import numpy as np
import os
from pprint import pprint
import torch
from tqdm import tqdm
import lightning.pytorch as pl
import matplotlib.pyplot as plt

from metric.precipitation_metrics import PrecipitationMetrics
from models.unet_precip_regression_lightning import PersistenceModel
from root import ROOT_DIR
from utils import dataset_precip, model_classes


def parse_args():
    parser = argparse.ArgumentParser(description="Calculate metrics for precipitation models")

    parser.add_argument("--model-folder", type=str, default=None,
                        help="Path to model folder (default: ROOT_DIR/checkpoints/comparison)")
    parser.add_argument("--threshold", type=float, default=0.5, help="Precipitation threshold")
    parser.add_argument("--normalized", action="store_false", dest="denormalize", default=True,
                        help="Calculate metrics for normalized data")
    parser.add_argument("--load-metrics", action="store_true",
                        help="Load existing metrics instead of running experiments")
    parser.add_argument("--file-type", type=str, default="txt", choices=["json", "txt", "csv"],
                        help="Output file format (json, txt, or csv)")
    parser.add_argument("--save-dir", type=str, default=None,
                        help="Path to save the metrics")
    parser.add_argument("--plot", action="store_true", help="Plot the metrics")
    
    return parser.parse_args()


def convert_tensors_to_python(obj):
    """Convert PyTorch tensors to Python native types recursively."""
    if isinstance(obj, torch.Tensor):
        # Convert tensor to float/int, handling nan values
        value = obj.item() if obj.numel() == 1 else obj.tolist()
        return float('nan') if isinstance(value, float) and np.isnan(value) else value
    elif isinstance(obj, dict):
        return {key: convert_tensors_to_python(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [convert_tensors_to_python(value) for value in obj]
    return obj


def save_metrics(results, file_path, file_type="json"):
    """
    Save metrics to a file in the specified format.
    
    Args:
        results (dict): Dictionary containing model metrics
        file_path (Path): Path to save the file
        file_type (str): File format - "json", "txt", or "csv"
    """
    with open(file_path, "w") as f:
        if file_type == "csv":
            writer = csv.writer(f)
            # Write header row
            if results:
                first_model = next(iter(results.values()))
                writer.writerow(["Model"] + list(first_model.keys()))
                # Write data rows
                for model_name, metrics in results.items():
                    writer.writerow([model_name] + list(metrics.values()))
        else:
            # For json and txt formats, use json.dump
            json.dump(results, f, indent=4)


def run_experiments(model_folder, data_file, threshold=0.5, denormalize=True):
    """
    Run test experiments for all models in the model folder.

    Args:
        model_folder (Path): Path to the model folder
        data_file (Path): Path to the data file
        threshold (float): Threshold for the precipitation

    Returns:
        dict: Dictionary containing model metrics
    """
    results = {}

    # Load the dataset
    dataset = dataset_precip.precipitation_maps_oversampled_h5(
        in_file=data_file, num_input_images=12, num_output_images=6, train=False
    )

    # Create dataloader
    test_dl = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False, pin_memory=True, persistent_workers=True)

    # Create trainer with device specification
    trainer = pl.Trainer(logger=False, enable_checkpointing=False)


    # Find all models in the model folder
    # Find all checkpoints first so the emptiness check is meaningful
    # (the PersistenceModel entry would otherwise make the list non-empty)
    ckpt_files = [f for f in os.listdir(model_folder) if f.endswith(".ckpt")]
    if not ckpt_files:
        raise ValueError("No checkpoint files found in the model folder.")
    models = ["PersistenceModel"] + ckpt_files

    for model_file in tqdm(models, desc="Models", leave=True):

        model, model_name = model_classes.get_model_class(model_file)
        print(f"Loading model: {model_name}")

        if model_file == "PersistenceModel":
            loaded_model = model(Namespace())
        else:
            loaded_model = model.load_from_checkpoint(os.path.join(model_folder, model_file))
        
        # Override the test metrics, required for testing with normalized data
        loaded_model.test_metrics = PrecipitationMetrics(threshold=threshold, denormalize=denormalize)

        results[model_name] = trainer.test(model=loaded_model, dataloaders=[test_dl])[0]

    return results


def plot_metrics(results, save_path=""):
    """
    Plot all metrics from the results dictionary.
    
    Args:
        results (dict): Dictionary containing metrics for different models
        save_path (str): Path to save the plots
    """
    model_names = list(results.keys())
    metrics = list(results[model_names[0]].keys())
    
    for metric_name in metrics:
        # Get values, replacing NaN with None to skip them in the plot
        values = []
        valid_names = []
        
        for model_name in model_names:
            val = results[model_name][metric_name]
            # Skip NaN values
            if not (isinstance(val, float) and np.isnan(val)):
                values.append(val)
                valid_names.append(model_name)
        
        if len(valid_names) == 0:
            continue
            
        plt.figure(figsize=(10, 6))
        plt.bar(valid_names, values)
        plt.xticks(rotation=45)
        plt.xlabel("Models")
        plt.ylabel(f"{metric_name.upper()}")
        plt.title(f"Comparison of different models - {metric_name.upper()}")
        plt.tight_layout()
        
        if save_path:
            plt.savefig(f"{save_path}/{metric_name}.png")
            
        plt.show()

    
def load_metrics_from_file(file_path, file_type):
    """
    Load metrics from a file in the specified format.
    
    Args:
        file_path (Path): Path to the metrics file
        file_type (str): File format - "json", "txt", or "csv"
        
    Returns:
        dict: Dictionary containing model metrics
    """
    results = {}
    
    with open(file_path) as f:
        if file_type == "json" or file_type == "txt":
            # Both json and txt files are saved in JSON format in this application
            results = json.load(f)
        elif file_type == "csv":
            reader = csv.reader(f)
            headers = next(reader)  # Get header row
            for row in reader:
                if len(row) > 0:
                    model_name = row[0]
                    # Create dictionary of metrics for this model
                    model_metrics = {}
                    for i in range(1, len(headers)):
                        if i < len(row):
                            # Try to convert to float if possible
                            try:
                                value = float(row[i])
                            except (ValueError, TypeError):
                                value = row[i]
                            model_metrics[headers[i]] = value
                    results[model_name] = model_metrics
    
    return results


def main():
    # Parse command line arguments
    args = parse_args()
    
    # Variables from command line arguments
    load_metrics = args.load_metrics
    threshold = args.threshold
    file_type = args.file_type
    denormalize = args.denormalize

    # Define paths
    model_folder = ROOT_DIR / "checkpoints" / "comparison" if args.model_folder is None else args.model_folder
    data_file = ROOT_DIR / "data" / "precipitation" / f"train_test_2016-2019_input-length_12_img-ahead_6_rain-threshold_{int(threshold*100)}.h5"
    save_dir = model_folder / f"calc_metrics_test_set_results_{'denorm' if denormalize else 'norm'}" if args.save_dir is None else args.save_dir
    os.makedirs(save_dir, exist_ok=True)

    file_name = f"model_metrics_{threshold}_{'denorm' if denormalize else 'norm'}.{file_type}"
    file_path = save_dir / file_name

    # Load metrics if available
    if load_metrics:
        print(f"Loading metrics from {file_path}")
        results = load_metrics_from_file(file_path, file_type)
    else:
        # Run experiments
        print(f"Running experiments with {threshold} threshold and {'denormalized' if denormalize else 'normalized'} data")
        results = run_experiments(model_folder, data_file, threshold=threshold, denormalize=denormalize)

        # Convert tensors to Python native types
        results = convert_tensors_to_python(results)

        save_metrics(results, file_path, file_type)

    # Display results
    print(f"Metrics saved to {file_path}")
    pprint(results)

    # Plot metrics if requested
    if args.plot:
        plots_dir = save_dir / "plots"
        os.makedirs(plots_dir, exist_ok=True)

        plot_metrics(results, plots_dir)
        print(f"Plots saved to {plots_dir}")


if __name__ == "__main__":
    main()

================================================
FILE: create_datasets.py
================================================
import h5py
import numpy as np
from tqdm import tqdm

from root import ROOT_DIR


def create_dataset(input_length: int, image_ahead: int, rain_amount_thresh: float):
    """Create a dataset whose target images contain at least `rain_amount_thresh` (a fraction, e.g. 0.2 = 20%) rain pixels."""

    precipitation_folder = ROOT_DIR / "data" / "precipitation"
    with h5py.File(
        precipitation_folder / "RAD_NL25_RAC_5min_train_test_2016-2019.h5",
        "r",
    ) as orig_f:
        train_images = orig_f["train"]["images"]
        train_timestamps = orig_f["train"]["timestamps"]
        test_images = orig_f["test"]["images"]
        test_timestamps = orig_f["test"]["timestamps"]
        print("Train shape", train_images.shape)
        print("Test shape", test_images.shape)
        imgSize = train_images.shape[1]
        num_pixels = imgSize * imgSize

        filename = (
            precipitation_folder / f"train_test_2016-2019_input-length_{input_length}_img-"
            f"ahead_{image_ahead}_rain-threshold_{int(rain_amount_thresh * 100)}.h5"
        )

        with h5py.File(filename, "w") as f:
            train_set = f.create_group("train")
            test_set = f.create_group("test")
            train_image_dataset = train_set.create_dataset(
                "images",
                shape=(1, input_length + image_ahead, imgSize, imgSize),
                maxshape=(None, input_length + image_ahead, imgSize, imgSize),
                dtype="float32",
                compression="gzip",
                compression_opts=9,
            )
            train_timestamp_dataset = train_set.create_dataset(
                "timestamps",
                shape=(1, input_length + image_ahead, 1),
                maxshape=(None, input_length + image_ahead, 1),
                dtype=h5py.special_dtype(vlen=str),
                compression="gzip",
                compression_opts=9,
            )
            test_image_dataset = test_set.create_dataset(
                "images",
                shape=(1, input_length + image_ahead, imgSize, imgSize),
                maxshape=(None, input_length + image_ahead, imgSize, imgSize),
                dtype="float32",
                compression="gzip",
                compression_opts=9,
            )
            test_timestamp_dataset = test_set.create_dataset(
                "timestamps",
                shape=(1, input_length + image_ahead, 1),
                maxshape=(None, input_length + image_ahead, 1),
                dtype=h5py.special_dtype(vlen=str),
                compression="gzip",
                compression_opts=9,
            )

            origin = [[train_images, train_timestamps], [test_images, test_timestamps]]
            datasets = [
                [train_image_dataset, train_timestamp_dataset],
                [test_image_dataset, test_timestamp_dataset],
            ]
            for origin_id, (images, timestamps) in enumerate(origin):
                image_dataset, timestamp_dataset = datasets[origin_id]
                
                # Pre-calculate all valid indices that meet the rain threshold
                sequence_length = input_length + image_ahead
                valid_indices = []
                
                # Scan through the images once to find all indices that meet the rain threshold
                for i in tqdm(range(sequence_length, len(images)), desc="Finding valid indices"):
                    rain_pixels = np.sum(images[i] > 0)
                    if rain_pixels >= num_pixels * rain_amount_thresh:
                        valid_indices.append(i)
                
                # Pre-allocate the final dataset size
                total_sequences = len(valid_indices)
                image_dataset.resize(total_sequences, axis=0)
                timestamp_dataset.resize(total_sequences, axis=0)
                
                # Batch process the sequences
                for idx, i in enumerate(tqdm(valid_indices, desc="Creating sequences")):
                    imgs = images[i - sequence_length : i]
                    timestamps_img = timestamps[i - sequence_length : i]
                    image_dataset[idx] = imgs
                    timestamp_dataset[idx] = timestamps_img


if __name__ == "__main__":
    print("Creating dataset with at least 20% rain pixels in the target image")
    create_dataset(input_length=12, image_ahead=6, rain_amount_thresh=0.2)
    print("Creating dataset with at least 50% rain pixels in the target image")
    create_dataset(input_length=12, image_ahead=6, rain_amount_thresh=0.5)


================================================
FILE: metric/__init__.py
================================================
from .confusionmatrix import ConfusionMatrix
from .iou import IoU
from .metric import Metric

__all__ = ["ConfusionMatrix", "IoU", "Metric"]


================================================
FILE: metric/confusionmatrix.py
================================================
import numpy as np
import torch
from metric import metric


class ConfusionMatrix(metric.Metric):
    """Constructs a confusion matrix for multi-class classification problems.
    Does not support multi-label, multi-class problems.
    Keyword arguments:
    - num_classes (int): number of classes in the classification problem.
    - normalized (boolean, optional): Determines whether or not the confusion
    matrix is normalized or not. Default: False.
    Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
    """

    def __init__(self, num_classes, normalized=False):
        super().__init__()

        self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32)
        self.normalized = normalized
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        self.conf.fill(0)

    def add(self, predicted, target):
        """Computes the confusion matrix
        The shape of the confusion matrix is K x K, where K is the number
        of classes.
        Keyword arguments:
        - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of
        predicted scores obtained from the model for N examples and K classes,
        or an N-tensor/array of integer values between 0 and K-1.
        - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of
        ground-truth classes for N examples and K classes, or an N-tensor/array
        of integer values between 0 and K-1.
        """
        # If target and/or predicted are tensors, convert them to numpy arrays
        if torch.is_tensor(predicted):
            predicted = predicted.cpu().numpy()
        if torch.is_tensor(target):
            target = target.cpu().numpy()

        assert predicted.shape[0] == target.shape[0], "number of targets and predicted outputs do not match"

        if np.ndim(predicted) != 1:
            assert (
                predicted.shape[1] == self.num_classes
            ), "number of predictions does not match size of confusion matrix"
            predicted = np.argmax(predicted, 1)
        else:
            assert (predicted.max() < self.num_classes) and (
                predicted.min() >= 0
            ), "predicted values are not between 0 and k-1"

        if np.ndim(target) != 1:
            assert target.shape[1] == self.num_classes, "Onehot target does not match size of confusion matrix"
            assert (target >= 0).all() and (target <= 1).all(), "in one-hot encoding, target values should be 0 or 1"
            assert (target.sum(1) == 1).all(), "multi-label setting is not supported"
            target = np.argmax(target, 1)
        else:
            assert (target.max() < self.num_classes) and (target.min() >= 0), "target values are not between 0 and k-1"

        # hack for bincounting 2 arrays together
        x = predicted + self.num_classes * target
        bincount_2d = np.bincount(x.astype(np.int32), minlength=self.num_classes**2)
        assert bincount_2d.size == self.num_classes**2
        conf = bincount_2d.reshape((self.num_classes, self.num_classes))

        self.conf += conf

    def value(self):
        """
        Returns:
            Confusion matrix of K rows and K columns, where rows correspond
            to ground-truth targets and columns correspond to predicted
            targets.
        """
        if self.normalized:
            conf = self.conf.astype(np.float32)
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        else:
            return self.conf


================================================
FILE: metric/iou.py
================================================
import numpy as np
from metric import metric
from metric.confusionmatrix import ConfusionMatrix


# Taken from https://github.com/davidtvs/PyTorch-ENet/tree/master/metric
class IoU(metric.Metric):
    """Computes the intersection over union (IoU) per class and corresponding
    mean (mIoU).
    Intersection over union (IoU) is a common evaluation metric for semantic
    segmentation. The predictions are first accumulated in a confusion matrix
    and the IoU is computed from it as follows:
        IoU = true_positive / (true_positive + false_positive + false_negative).
    Keyword arguments:
    - num_classes (int): number of classes in the classification problem
    - normalized (boolean, optional): Determines whether or not the confusion
    matrix is normalized or not. Default: False.
    - ignore_index (int or iterable, optional): Index of the classes to ignore
    when computing the IoU. Can be an int, or any iterable of ints.
    """

    def __init__(self, num_classes, normalized=False, ignore_index=None):
        super().__init__()
        self.conf_metric = ConfusionMatrix(num_classes, normalized)

        if ignore_index is None:
            self.ignore_index = None
        elif isinstance(ignore_index, int):
            self.ignore_index = (ignore_index,)
        else:
            try:
                self.ignore_index = tuple(ignore_index)
            except TypeError as err:
                raise ValueError("'ignore_index' must be an int or iterable") from err

    def reset(self):
        self.conf_metric.reset()

    def add(self, predicted, target):
        """Adds the predicted and target pair to the IoU metric.
        Keyword arguments:
        - predicted (Tensor): Can be a (N, K, H, W) tensor of
        predicted scores obtained from the model for N examples and K classes,
        or (N, H, W) tensor of integer values between 0 and K-1.
        - target (Tensor): Can be a (N, K, H, W) tensor of
        target scores for N examples and K classes, or (N, H, W) tensor of
        integer values between 0 and K-1.
        """
        # Dimensions check
        assert predicted.size(0) == target.size(0), "number of targets and predicted outputs do not match"
        assert (
            predicted.dim() == 3 or predicted.dim() == 4
        ), "predictions must be of dimension (N, H, W) or (N, K, H, W)"
        assert target.dim() == 3 or target.dim() == 4, "targets must be of dimension (N, H, W) or (N, K, H, W)"

        # If the tensor is in categorical format convert it to integer format
        if predicted.dim() == 4:
            _, predicted = predicted.max(1)
        if target.dim() == 4:
            _, target = target.max(1)

        self.conf_metric.add(predicted.view(-1), target.view(-1))

    def value(self):
        """Computes the IoU and mean IoU.
        The mean computation ignores NaN elements of the IoU array.
        Returns:
            Tuple: (IoU, mIoU). The first output is the per class IoU,
            for K classes it's numpy.ndarray with K elements. The second output,
            is the mean IoU.
        """
        conf_matrix = self.conf_metric.value()
        if self.ignore_index is not None:
            conf_matrix[:, self.ignore_index] = 0
            conf_matrix[self.ignore_index, :] = 0
        true_positive = np.diag(conf_matrix)
        false_positive = np.sum(conf_matrix, 0) - true_positive
        false_negative = np.sum(conf_matrix, 1) - true_positive

        # Just in case we get a division by 0, ignore/hide the error
        with np.errstate(divide="ignore", invalid="ignore"):
            iou = true_positive / (true_positive + false_positive + false_negative)

        return iou, np.nanmean(iou)


================================================
FILE: metric/metric.py
================================================
from typing import Protocol


class Metric(Protocol):
    """Protocol that all metric implementations follow."""

    def reset(self):
        pass

    def add(self, predicted, target):
        pass

    def value(self):
        pass


================================================
FILE: metric/precipitation_metrics.py
================================================
import torch
from torchmetrics import Metric
import numpy as np


class PrecipitationMetrics(Metric):
    """
    A custom metric for precipitation forecasting that computes both regression metrics (MSE)
    and classification metrics (precision, recall, F1, etc.) based on a threshold.
    """
    
    def __init__(self, threshold=0.5, denormalize=True, dist_sync_on_step=False):
        """
        Args:
            threshold: Threshold to convert continuous predictions into binary masks (in mm/h)
            denormalize: If True, applies the factor to undo normalization
            dist_sync_on_step: Synchronize metric state across processes at each step
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        
        self.threshold = threshold
        self.denormalize = denormalize
        self.factor = 47.83 # Factor to denormalize predictions (mm/h)
        
        # Add states for regression metrics
        self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_loss_denorm", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_samples", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total_pixels", default=torch.tensor(0), dist_reduce_fx="sum")
        
        # Add states for classification metrics
        self.add_state("total_tp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total_fp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total_tn", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total_fn", default=torch.tensor(0), dist_reduce_fx="sum")
    
    def update(self, preds, target):
        """
        Update the metric states with new predictions and targets.
        
        Args:
            preds: Model predictions (normalized)
            target: Ground truth (normalized)
        """
        # Check for NaN values
        if torch.isnan(preds).any() or torch.isnan(target).any():
            print("Warning: NaN values detected in predictions or targets")
            return
        
        # Make sure preds and target have the same shape
        if preds.shape != target.shape:
            if len(preds.shape) < len(target.shape):
                preds = preds.unsqueeze(0)
            elif len(preds.shape) > len(target.shape):
                preds = preds.squeeze()
                # If squeezing made it too small, add dimension back
                if len(preds.shape) < len(target.shape):
                    preds = preds.unsqueeze(0)
        
        # Accumulate the raw sum of squared errors (normalized). compute() divides
        # by the accumulated sample count, so unevenly sized batches (e.g. the last
        # one) are weighted correctly; summing per-batch averages would not be.
        batch_size = target.size(0)
        self.total_loss += torch.nn.functional.mse_loss(preds, target, reduction="sum")
        self.total_samples += batch_size
        self.total_pixels += target.numel()

        # Denormalize if needed
        if self.denormalize:
            preds_updated = preds * self.factor
            target_updated = target * self.factor

            # Accumulate the denormalized sum of squared errors
            self.total_loss_denorm += torch.nn.functional.mse_loss(
                preds_updated, target_updated, reduction="sum"
            )
        else:
            preds_updated = preds
            target_updated = target
            
        # Convert to mm/h for classification metrics (multiply by 12 for 5min to hourly rate)
        preds_hourly = preds_updated * 12
        target_hourly = target_updated * 12
        
        # Apply threshold to get binary masks
        preds_mask = preds_hourly > self.threshold
        target_mask = target_hourly > self.threshold
        
        # Compute confusion matrix
        confusion = target_mask.view(-1) * 2 + preds_mask.view(-1)
        bincount = torch.bincount(confusion, minlength=4)
        
        # Update confusion matrix states
        self.total_tn += bincount[0]
        self.total_fp += bincount[1]
        self.total_fn += bincount[2]
        self.total_tp += bincount[3]
    
    def compute(self):
        """
        Compute the final metrics from the accumulated states.
        
        Returns:
            A dictionary containing all metrics.
        """
        # Compute regression metrics
        mse = self.total_loss / self.total_samples
        mse_denorm = self.total_loss_denorm / self.total_samples if self.denormalize else torch.tensor(float('nan'))
        mse_pixel = self.total_loss_denorm / self.total_pixels if self.denormalize else torch.tensor(float('nan'))
        
        # Compute classification metrics
        precision = self.total_tp / (self.total_tp + self.total_fp) if (self.total_tp + self.total_fp) > 0 else torch.tensor(float('nan'))
        recall = self.total_tp / (self.total_tp + self.total_fn) if (self.total_tp + self.total_fn) > 0 else torch.tensor(float('nan'))
        accuracy = (self.total_tp + self.total_tn) / (self.total_tp + self.total_tn + self.total_fp + self.total_fn)
        
        # F1 score
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(float('nan'))
        
        # Critical Success Index (CSI) or Threat Score
        csi = self.total_tp / (self.total_tp + self.total_fn + self.total_fp) if (self.total_tp + self.total_fn + self.total_fp) > 0 else torch.tensor(float('nan'))
        
        # False Alarm Ratio (FAR)
        far = self.total_fp / (self.total_tp + self.total_fp) if (self.total_tp + self.total_fp) > 0 else torch.tensor(float('nan'))
        
        # Heidke Skill Score (HSS)
        denom = ((self.total_tp + self.total_fn) * (self.total_fn + self.total_tn) + 
                 (self.total_tp + self.total_fp) * (self.total_fp + self.total_tn))
        hss = ((self.total_tp * self.total_tn) - (self.total_fn * self.total_fp)) / denom if denom > 0 else torch.tensor(float('nan'))
        
        return {
            "mse": mse,
            "mse_denorm": mse_denorm,
            "mse_pixel": mse_pixel,
            "precision": precision,
            "recall": recall,
            "accuracy": accuracy,
            "f1": f1,
            "csi": csi,
            "far": far,
            "hss": hss
        } 
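The `target * 2 + pred` encoding in `update` maps each pixel to one of four bins (0 = TN, 1 = FP, 2 = FN, 3 = TP) before `bincount`. A torch-free sketch of the same bookkeeping on hand-made masks (illustrative only, not part of the repository):

```python
# Binary masks for 6 pixels (1 = rain above threshold)
target = [0, 0, 1, 1, 0, 1]
preds = [0, 1, 0, 1, 0, 1]

# Encode each pixel as target*2 + pred, then count occurrences of 0..3
bins = [0, 0, 0, 0]
for t, p in zip(target, preds):
    bins[t * 2 + p] += 1

tn, fp, fn, tp = bins  # 2, 1, 1, 2
precision = tp / (tp + fp)
recall = tp / (tp + fn)
csi = tp / (tp + fn + fp)  # Critical Success Index
print(tn, fp, fn, tp)
```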

================================================
FILE: models/SmaAt_UNet.py
================================================
from torch import nn
from models.unet_parts import OutConv
from models.unet_parts_depthwise_separable import DoubleConvDS, UpDS, DownDS
from models.layers import CBAM


class SmaAt_UNet(nn.Module):
    def __init__(
        self,
        n_channels,
        n_classes,
        kernels_per_layer=2,
        bilinear=True,
        reduction_ratio=16,
    ):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)
        self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
        self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)
        self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
        self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)
        self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
        self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)
        self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
        factor = 2 if self.bilinear else 1
        self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)
        self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)
        self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)

        self.outc = OutConv(64, self.n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x1Att = self.cbam1(x1)
        x2 = self.down1(x1)
        x2Att = self.cbam2(x2)
        x3 = self.down2(x2)
        x3Att = self.cbam3(x3)
        x4 = self.down3(x3)
        x4Att = self.cbam4(x4)
        x5 = self.down4(x4)
        x5Att = self.cbam5(x5)
        x = self.up1(x5Att, x4Att)
        x = self.up2(x, x3Att)
        x = self.up3(x, x2Att)
        x = self.up4(x, x1Att)
        logits = self.outc(x)
        return logits
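With bilinear upsampling, `factor` halves the bottleneck and decoder output widths so that channel counts line up after the skip-connection concatenations. A small sketch of that arithmetic (hypothetical helper name, assuming the constructor above):

```python
def smaat_unet_widths(bilinear: bool = True):
    """Output channel widths of the encoder and decoder blocks above."""
    factor = 2 if bilinear else 1
    encoder = [64, 128, 256, 512, 1024 // factor]
    decoder = [512 // factor, 256 // factor, 128 // factor, 64]
    return encoder, decoder

enc, dec = smaat_unet_widths()
print(enc)  # [64, 128, 256, 512, 512]
print(dec)  # [256, 128, 64, 64]
```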


================================================
FILE: models/__init__.py
================================================


================================================
FILE: models/layers.py
================================================
import torch
from torch import nn
import torch.nn.functional as F


# Taken from https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14
class DepthToSpace(nn.Module):
    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, self.bs, self.bs, C // (self.bs**2), H, W)  # (N, bs, bs, C//bs^2, H, W)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)
        x = x.view(N, C // (self.bs**2), H * self.bs, W * self.bs)  # (N, C//bs^2, H * bs, W * bs)
        return x


class SpaceToDepth(nn.Module):
    # Expects the following shape: Batch, Channel, Height, Width
    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)  # (N, C, H//bs, bs, W//bs, bs)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        x = x.view(N, C * (self.bs**2), H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
        return x
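The same view/permute/view pipeline can be reproduced in NumPy to see the pixel shuffle concretely (illustrative sketch, assuming a 1x1x2x2 input and block_size=2):

```python
import numpy as np

bs = 2
x = np.arange(4).reshape(1, 1, 2, 2)  # one 2x2 single-channel image

# Mirror SpaceToDepth: (N, C, H//bs, bs, W//bs, bs) -> permute -> merge into channels
N, C, H, W = x.shape
y = (
    x.reshape(N, C, H // bs, bs, W // bs, bs)
    .transpose(0, 3, 5, 1, 2, 4)
    .reshape(N, C * bs**2, H // bs, W // bs)
)
print(y.ravel().tolist())  # each 2x2 spatial block becomes 4 channels
```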


class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, output_channels, kernel_size, padding=0, kernels_per_layer=1):
        super().__init__()
        # In Tensorflow DepthwiseConv2D has depth_multiplier instead of kernels_per_layer
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels * kernels_per_layer,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_channels,
        )
        self.pointwise = nn.Conv2d(in_channels * kernels_per_layer, output_channels, kernel_size=1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x
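The parameter savings of the depthwise separable factorization can be checked by hand (pure arithmetic sketch, bias terms ignored for brevity):

```python
def conv_params(c_in, c_out, k=3):
    """Weights in a standard k x k convolution (no bias)."""
    return c_in * c_out * k * k

def ds_conv_params(c_in, c_out, k=3, kernels_per_layer=1):
    """Weights in the depthwise + pointwise factorization above (no bias)."""
    depthwise = c_in * kernels_per_layer * k * k  # groups=c_in
    pointwise = c_in * kernels_per_layer * c_out  # 1x1 conv
    return depthwise + pointwise

# 64 -> 128 channels with a 3x3 kernel
print(conv_params(64, 128))     # 73728
print(ds_conv_params(64, 128))  # 576 + 8192 = 8768
```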


class DoubleDense(nn.Module):
    def __init__(self, in_channels, hidden_neurons, output_channels):
        super().__init__()
        self.dense1 = nn.Linear(in_channels, out_features=hidden_neurons)
        self.dense2 = nn.Linear(in_features=hidden_neurons, out_features=hidden_neurons // 2)
        self.dense3 = nn.Linear(in_features=hidden_neurons // 2, out_features=output_channels)

    def forward(self, x):
        out = F.relu(self.dense1(x.view(x.size(0), -1)))
        out = F.relu(self.dense2(out))
        out = self.dense3(out)
        return out


class DoubleDSConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_ds_conv = nn.Sequential(
            DepthwiseSeparableConv(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            DepthwiseSeparableConv(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_ds_conv(x)


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelAttention(nn.Module):
    def __init__(self, input_channels, reduction_ratio=16):
        super().__init__()
        self.input_channels = input_channels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        #  https://github.com/luuuyi/CBAM.PyTorch/blob/master/model/resnet_cbam.py
        #  uses Convolutions instead of Linear
        self.MLP = nn.Sequential(
            Flatten(),
            nn.Linear(input_channels, input_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(input_channels // reduction_ratio, input_channels),
        )

    def forward(self, x):
        # Take the input and apply average and max pooling
        avg_values = self.avg_pool(x)
        max_values = self.max_pool(x)
        out = self.MLP(avg_values) + self.MLP(max_values)
        scale = x * torch.sigmoid(out).unsqueeze(2).unsqueeze(3).expand_as(x)
        return scale


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(1)

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        out = self.bn(out)
        scale = x * torch.sigmoid(out)
        return scale


class CBAM(nn.Module):
    def __init__(self, input_channels, reduction_ratio=16, kernel_size=7):
        super().__init__()
        self.channel_att = ChannelAttention(input_channels, reduction_ratio=reduction_ratio)
        self.spatial_att = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        out = self.channel_att(x)
        out = self.spatial_att(out)
        return out


================================================
FILE: models/regression_lightning.py
================================================
import lightning.pytorch as pl
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from utils import dataset_precip
import argparse
import numpy as np
from metric.precipitation_metrics import PrecipitationMetrics
from utils.formatting import make_metrics_str


class UNetBase(pl.LightningModule):
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--model",
            type=str,
            default="UNet",
            choices=["UNet", "UNetDS", "UNetAttention", "UNetDSAttention", "PersistenceModel"],
        )
        parser.add_argument("--n_channels", type=int, default=12)
        parser.add_argument("--n_classes", type=int, default=1)
        parser.add_argument("--kernels_per_layer", type=int, default=1)
        # NB: argparse's type=bool treats any non-empty string as True
        parser.add_argument("--bilinear", type=lambda s: str(s).lower() in ("true", "1", "yes"), default=True)
        parser.add_argument("--reduction_ratio", type=int, default=16)
        parser.add_argument("--lr_patience", type=int, default=5)
        parser.add_argument("--threshold", type=float, default=0.5)
        return parser

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        threshold = getattr(self.hparams, "threshold", 0.5)
        self.train_metrics = PrecipitationMetrics(threshold=threshold)
        self.val_metrics = PrecipitationMetrics(threshold=threshold)
        self.test_metrics = PrecipitationMetrics(threshold=threshold)

    def forward(self, x):
        raise NotImplementedError("Subclasses must implement forward")

    def configure_optimizers(self):
        opt = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = {
            "scheduler": optim.lr_scheduler.ReduceLROnPlateau(
                opt, mode="min", factor=0.1, patience=self.hparams.lr_patience
            ),
            "monitor": "val_loss",  # Default: val_loss
        }
        return [opt], [scheduler]

    def loss_func(self, y_pred, y_true):
        # Ensure consistent shapes before computing loss
        if y_pred.dim() > y_true.dim():
            y_pred = y_pred.squeeze(1)
        elif y_true.dim() > y_pred.dim():
            y_pred = y_pred.unsqueeze(1)

        # reduction="mean" would average over every pixel; instead average the
        # per-image sum of squared errors over the batch
        return nn.functional.mse_loss(y_pred, y_true, reduction="sum") / y_true.size(0)
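Dividing the summed squared error by the batch size averages over images rather than pixels; a torch-free numeric check of the distinction (illustrative):

```python
# Two "images" of two pixels each; squared errors per pixel
sq_err = [[1.0, 3.0], [2.0, 2.0]]

total = sum(sum(img) for img in sq_err)  # 8.0
per_pixel_mean = total / 4               # like reduction="mean": 2.0
per_image_mean = total / len(sq_err)     # like reduction="sum" / batch: 4.0
print(per_pixel_mean, per_image_mean)
```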

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.loss_func(y_pred, y)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        # Update training metrics with detached tensors
        self.train_metrics.update(y_pred.detach(), y.detach())
        
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.loss_func(y_pred, y)
        self.log("val_loss", loss, prog_bar=True)
        
        # Update validation metrics with detached tensors
        self.val_metrics.update(y_pred.detach(), y.detach())

    def test_step(self, batch, batch_idx):
        """Calculate the loss (MSE per default) and other metrics on the test set normalized and denormalized."""
        x, y = batch
        y_pred = self(x)

        # Update the test metrics
        self.test_metrics.update(y_pred.detach(), y.detach())
    
    def on_test_epoch_end(self):
        """Compute and log all metrics at the end of the test epoch."""
        test_metrics_dict = self.test_metrics.compute()
        
        # Print all metrics in one line
        print(f"\n\nEpoch {self.current_epoch} - Test Metrics: {make_metrics_str(test_metrics_dict)}")
        
        # Reset the metrics for the next test epoch
        self.test_metrics.reset()

    def on_train_epoch_end(self):
        """Compute and log all metrics at the end of the training epoch."""
        train_metrics_dict = self.train_metrics.compute()
        
        print(f"\n\nEpoch {self.current_epoch} - Train Metrics: {make_metrics_str(train_metrics_dict)}")
        self.train_metrics.reset()

    def on_validation_epoch_end(self):
        """Compute and log all metrics at the end of the validation epoch."""
        val_metrics_dict = self.val_metrics.compute()
        
        print(f"\n\nEpoch {self.current_epoch} - Validation Metrics: {make_metrics_str(val_metrics_dict)}")
        self.val_metrics.reset()


class PrecipRegressionBase(UNetBase):
    @staticmethod
    def add_model_specific_args(parent_parser):
        parent_parser = UNetBase.add_model_specific_args(parent_parser)
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--num_input_images", type=int, default=12)
        parser.add_argument("--num_output_images", type=int, default=6)
        parser.add_argument("--valid_size", type=float, default=0.1)
        # NB: argparse's type=bool treats any non-empty string as True
        parser.add_argument(
            "--use_oversampled_dataset",
            type=lambda s: str(s).lower() in ("true", "1", "yes"),
            default=True,
        )
        # Plain attribute assignment on the parser object is ignored by
        # parse_args(), so tie n_channels to num_input_images (and pin
        # n_classes) through the parser defaults instead.
        args, _ = parser.parse_known_args()
        parser.set_defaults(n_channels=args.num_input_images, n_classes=1)
        return parser

    def __init__(self, hparams):
        super().__init__(hparams=hparams)
        self.train_dataset = None
        self.valid_dataset = None
        self.train_sampler = None
        self.valid_sampler = None

    def prepare_data(self):
        # train_transform = transforms.Compose([
        #     transforms.RandomHorizontalFlip()]
        # )
        train_transform = None
        valid_transform = None
        precip_dataset = (
            dataset_precip.precipitation_maps_oversampled_h5
            if self.hparams.use_oversampled_dataset
            else dataset_precip.precipitation_maps_h5
        )
        self.train_dataset = precip_dataset(
            in_file=self.hparams.dataset_folder,
            num_input_images=self.hparams.num_input_images,
            num_output_images=self.hparams.num_output_images,
            train=True,
            transform=train_transform,
        )
        self.valid_dataset = precip_dataset(
            in_file=self.hparams.dataset_folder,
            num_input_images=self.hparams.num_input_images,
            num_output_images=self.hparams.num_output_images,
            train=True,
            transform=valid_transform,
        )

        num_train = len(self.train_dataset)
        indices = list(range(num_train))
        split = int(np.floor(self.hparams.valid_size * num_train))

        np.random.shuffle(indices)

        train_idx, valid_idx = indices[split:], indices[:split]
        self.train_sampler = SubsetRandomSampler(train_idx)
        self.valid_sampler = SubsetRandomSampler(valid_idx)
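The split arithmetic above carves the first `valid_size` fraction of the shuffled indices off for validation; a miniature version of the same bookkeeping (illustrative, seeded only to keep the example deterministic):

```python
import numpy as np

num_train = 10
valid_size = 0.2
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))  # 2

rng = np.random.default_rng(0)  # prepare_data() shuffles without a fixed seed
rng.shuffle(indices)

train_idx, valid_idx = indices[split:], indices[:split]
print(len(train_idx), len(valid_idx))  # 8 2
```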

    def train_dataloader(self):
        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            sampler=self.train_sampler,
            pin_memory=True,
            # The following can/should be tweaked depending on the number of CPU cores
            num_workers=1,
            persistent_workers=True,
        )
        return train_loader

    def val_dataloader(self):
        valid_loader = DataLoader(
            self.valid_dataset,
            batch_size=self.hparams.batch_size,
            sampler=self.valid_sampler,
            pin_memory=True,
            # The following can/should be tweaked depending on the number of CPU cores
            num_workers=1,
            persistent_workers=True,
        )
        return valid_loader


class PersistenceModel(UNetBase):
    def forward(self, x):
        return x[:, -1:, :, :]
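The persistence baseline simply repeats the most recent input frame as the forecast; the `-1:` slice keeps the channel dimension. A toy list-based sketch of the same slicing (illustrative only):

```python
# T input frames (toy 1-D "images"); the forecast is the last frame, dim kept
frames = [[1, 2], [3, 4], [5, 6]]
forecast = frames[-1:]  # analogous to x[:, -1:, :, :]
print(forecast)  # [[5, 6]]
```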


================================================
FILE: models/unet_parts.py
================================================
""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is NCHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
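The pad amounts split any odd size difference between the two sides so the upsampled map matches the skip connection exactly; worked numbers (illustrative):

```python
# Skip connection x2 is 57x57, upsampled x1 is 54x54
h2, w2 = 57, 57
h1, w1 = 54, 54

diffY = h2 - h1  # 3
diffX = w2 - w1  # 3

# [left, right, top, bottom] as passed to F.pad
pad = [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
print(pad)  # [1, 2, 1, 2] -> padded x1 becomes 57x57
```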


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


================================================
FILE: models/unet_parts_depthwise_separable.py
================================================
""" Parts of the U-Net model """

# Base model taken from: https://github.com/milesial/Pytorch-UNet
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.layers import DepthwiseSeparableConv


class DoubleConvDS(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None, kernels_per_layer=1):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            DepthwiseSeparableConv(
                in_channels,
                mid_channels,
                kernel_size=3,
                kernels_per_layer=kernels_per_layer,
                padding=1,
            ),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            DepthwiseSeparableConv(
                mid_channels,
                out_channels,
                kernel_size=3,
                kernels_per_layer=kernels_per_layer,
                padding=1,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class DownDS(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels, kernels_per_layer=1):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConvDS(in_channels, out_channels, kernels_per_layer=kernels_per_layer),
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class UpDS(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True, kernels_per_layer=1):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConvDS(
                in_channels,
                out_channels,
                in_channels // 2,
                kernels_per_layer=kernels_per_layer,
            )
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConvDS(in_channels, out_channels, kernels_per_layer=kernels_per_layer)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is NCHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


================================================
FILE: models/unet_precip_regression_lightning.py
================================================
from models.unet_parts import Down, DoubleConv, Up, OutConv
from models.unet_parts_depthwise_separable import DoubleConvDS, UpDS, DownDS
from models.layers import CBAM
from models.regression_lightning import PrecipRegressionBase, PersistenceModel


class UNet(PrecipRegressionBase):
    def __init__(self, hparams):
        super().__init__(hparams=hparams)
        self.n_channels = self.hparams.n_channels
        self.n_classes = self.hparams.n_classes
        self.bilinear = self.hparams.bilinear

        self.inc = DoubleConv(self.n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if self.bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, self.bilinear)
        self.up2 = Up(512, 256 // factor, self.bilinear)
        self.up3 = Up(256, 128 // factor, self.bilinear)
        self.up4 = Up(128, 64, self.bilinear)

        self.outc = OutConv(64, self.n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


class UNetAttention(PrecipRegressionBase):
    def __init__(self, hparams):
        super().__init__(hparams=hparams)
        self.n_channels = self.hparams.n_channels
        self.n_classes = self.hparams.n_classes
        self.bilinear = self.hparams.bilinear
        reduction_ratio = self.hparams.reduction_ratio

        self.inc = DoubleConv(self.n_channels, 64)
        self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
        self.down1 = Down(64, 128)
        self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
        self.down2 = Down(128, 256)
        self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
        self.down3 = Down(256, 512)
        self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
        factor = 2 if self.bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)
        self.up1 = Up(1024, 512 // factor, self.bilinear)
        self.up2 = Up(512, 256 // factor, self.bilinear)
        self.up3 = Up(256, 128 // factor, self.bilinear)
        self.up4 = Up(128, 64, self.bilinear)

        self.outc = OutConv(64, self.n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x1Att = self.cbam1(x1)
        x2 = self.down1(x1)
        x2Att = self.cbam2(x2)
        x3 = self.down2(x2)
        x3Att = self.cbam3(x3)
        x4 = self.down3(x3)
        x4Att = self.cbam4(x4)
        x5 = self.down4(x4)
        x5Att = self.cbam5(x5)
        x = self.up1(x5Att, x4Att)
        x = self.up2(x, x3Att)
        x = self.up3(x, x2Att)
        x = self.up4(x, x1Att)
        logits = self.outc(x)
        return logits


class UNetDS(PrecipRegressionBase):
    def __init__(self, hparams):
        super().__init__(hparams=hparams)
        self.n_channels = self.hparams.n_channels
        self.n_classes = self.hparams.n_classes
        self.bilinear = self.hparams.bilinear
        kernels_per_layer = self.hparams.kernels_per_layer

        self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)
        self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)
        self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)
        self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)
        factor = 2 if self.bilinear else 1
        self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)
        self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)

        self.outc = OutConv(64, self.n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


class UNetDSAttention(PrecipRegressionBase):
    def __init__(self, hparams):
        super().__init__(hparams=hparams)
        self.n_channels = self.hparams.n_channels
        self.n_classes = self.hparams.n_classes
        self.bilinear = self.hparams.bilinear
        reduction_ratio = self.hparams.reduction_ratio
        kernels_per_layer = self.hparams.kernels_per_layer

        self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)
        self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
        self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)
        self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
        self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)
        self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
        self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)
        self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
        factor = 2 if self.bilinear else 1
        self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)
        self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)
        self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)

        self.outc = OutConv(64, self.n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x1Att = self.cbam1(x1)
        x2 = self.down1(x1)
        x2Att = self.cbam2(x2)
        x3 = self.down2(x2)
        x3Att = self.cbam3(x3)
        x4 = self.down3(x3)
        x4Att = self.cbam4(x4)
        x5 = self.down4(x4)
        x5Att = self.cbam5(x5)
        x = self.up1(x5Att, x4Att)
        x = self.up2(x, x3Att)
        x = self.up3(x, x2Att)
        x = self.up4(x, x1Att)
        logits = self.outc(x)
        return logits


class UNetDSAttention4CBAMs(PrecipRegressionBase):
    def __init__(self, hparams):
        super().__init__(hparams=hparams)
        self.n_channels = self.hparams.n_channels
        self.n_classes = self.hparams.n_classes
        self.bilinear = self.hparams.bilinear
        reduction_ratio = self.hparams.reduction_ratio
        kernels_per_layer = self.hparams.kernels_per_layer

        self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)
        self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
        self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)
        self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
        self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)
        self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
        self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)
        self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
        factor = 2 if self.bilinear else 1
        self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)
        self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)

        self.outc = OutConv(64, self.n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x1Att = self.cbam1(x1)
        x2 = self.down1(x1)
        x2Att = self.cbam2(x2)
        x3 = self.down2(x2)
        x3Att = self.cbam3(x3)
        x4 = self.down3(x3)
        x4Att = self.cbam4(x4)
        x5 = self.down4(x4)
        x = self.up1(x5, x4Att)
        x = self.up2(x, x3Att)
        x = self.up3(x, x2Att)
        x = self.up4(x, x1Att)
        logits = self.outc(x)
        return logits
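
The `factor = 2 if bilinear else 1` bookkeeping used by all the UNet variants above determines the channel widths of the bottleneck and decoder stages: with bilinear upsampling there is no transposed convolution to halve the channels, so the widths are divided by 2 up front. A minimal stdlib sketch of that arithmetic (`unet_channel_plan` is a hypothetical helper, not part of the repository):

```python
def unet_channel_plan(base: int = 64, depth: int = 4, bilinear: bool = True):
    """Reproduce the encoder/decoder channel widths used by the UNet variants.

    With bilinear upsampling (factor=2) the bottleneck and every decoder
    output except the last are halved, matching e.g. Up(1024, 512 // factor).
    """
    factor = 2 if bilinear else 1
    encoder = [base * 2**i for i in range(depth)]           # 64, 128, 256, 512
    bottleneck = base * 2**depth // factor                  # 1024 // factor
    # Each Up block halves the width again, except the last, which
    # restores the base width fed into OutConv(64, n_classes).
    decoder = [base * 2**i // factor for i in range(depth - 1, 0, -1)] + [base]
    return encoder, bottleneck, decoder

enc, mid, dec = unet_channel_plan(bilinear=True)
# enc == [64, 128, 256, 512], mid == 512, dec == [256, 128, 64, 64]
```

This mirrors why `Down(512, 1024 // factor)` and `Up(1024, 512 // factor)` appear together: the skip concatenation restores the full 1024 input channels for the first Up block regardless of `factor`.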


================================================
FILE: plot_examples.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import models.unet_precip_regression_lightning as unet_regr\n",
    "from utils import dataset_precip\n",
    "\n",
    "dataset = dataset_precip.precipitation_maps_oversampled_h5(\n",
    "            in_file=\"data/precipitation/train_test_2016-2019_input-length_12_img-ahead_6_rain-threshold_50.h5\",\n",
    "            num_input_images=12,\n",
    "            num_output_images=6, train=False)\n",
    "# 222, 444, 777, 1337 are fine\n",

    "x, y_true = dataset[777]\n",
    "# add batch dimension\n",
    "x = torch.tensor(x).unsqueeze(0)\n",
    "y_true = torch.tensor(y_true).unsqueeze(0)\n",
    "print(x.shape, y_true.shape)\n",
    "\n",
    "models_dir = \"checkpoints/comparison\"\n",
    "precip_models = [m for m in os.listdir(models_dir) if \".ckpt\" in m]\n",
    "print(precip_models)\n",
    "\n",
    "loss_func = nn.MSELoss(reduction=\"sum\")\n",
    "print(\"Persistence\", loss_func(x.squeeze()[-1]*47.83, y_true.squeeze()*47.83).item())\n",
    "\n",
    "plot_images = dict()\n",
    "for i, model_file in enumerate(precip_models):\n",
    "    # model_name = \"\".join(model_file.split(\"_\")[0])\n",
    "    if \"UNetAttention\" in model_file:\n",
    "        model_name = \"UNet with CBAM\"\n",
    "        model = unet_regr.UNetAttention\n",
    "    elif \"UNetDSAttention4CBAMs\" in model_file:\n",
    "        model_name = \"UNetDS Attention 4CBAMs\"\n",
    "        model = unet_regr.UNetDSAttention4CBAMs\n",
    "    elif \"UNetDSAttention\" in model_file:\n",
    "        model_name = \"SmaAt-UNet\"\n",
    "        model = unet_regr.UNetDSAttention\n",
    "    elif \"UNetDS\" in model_file:\n",
    "        model_name = \"UNet with DSC\"\n",
    "        model = unet_regr.UNetDS\n",
    "    elif \"UNet\" in model_file:\n",
    "        model_name = \"UNet\"\n",
    "        model = unet_regr.UNet\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Model for file {model_file} not found\")\n",
    "    model = model.load_from_checkpoint(f\"{models_dir}/{model_file}\")\n",
    "    # loaded = torch.load(f\"{models_dir}/{model_file}\")\n",
    "    # model = loaded[\"model\"]\n",
    "    # model.load_state_dict(loaded[\"state_dict\"])\n",
    "    # loss_func = nn.CrossEntropyLoss(weight=torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.]))\n",
    "\n",
    "    model.to(\"cpu\")\n",
    "    model.eval()\n",
    "\n",
    "    # get prediction of model\n",
    "    with torch.no_grad():\n",
    "        y_pred = model(x)\n",
    "    y_pred = y_pred.squeeze()\n",
    "    print(model_name, loss_func(y_pred*47.83, y_true.squeeze()*47.83).item())\n",
    "\n",
    "    plot_images[model_name] = y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#### Plotting ####\n",
    "plt.figure(figsize=(10, 20))\n",
    "plt.subplot(len(precip_models)+2, 1, 1)\n",
    "plt.imshow(y_true.squeeze()*47.83)\n",
    "plt.colorbar()\n",
    "plt.title(\"Ground truth\")\n",
    "plt.axis(\"off\")\n",
    "\n",
    "plt.subplot(len(precip_models)+2, 1, 2)\n",
    "plt.imshow(x.squeeze()[-1]*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\n",
    "plt.colorbar()\n",
    "plt.title(\"Persistence\")\n",
    "plt.axis(\"off\")\n",
    "\n",
    "for i, (model_name, y_pred) in enumerate(plot_images.items()):\n",
    "    print(model_name)\n",
    "    plt.subplot(len(precip_models)+2, 1, i+3)\n",
    "    plt.imshow(y_pred*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\n",
    "    plt.colorbar()\n",
    "    plt.title(model_name)\n",
    "    plt.axis(\"off\")\n",
    "# plt.colorbar()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "fig, axesgrid = plt.subplots(nrows=2, ncols=(len(precip_models)+2)//2, figsize=(20, 10))\n",
    "axes = axesgrid.flat\n",
    "im = axes[0].imshow(y_true.squeeze()*47.83)\n",
    "axes[0].set_title(\"Groundtruth\", fontsize=15)\n",
    "axes[0].set_axis_off()\n",
    "axes[1].imshow(x.squeeze()[-1]*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\n",
    "axes[1].set_title(\"Persistence\", fontsize=15)\n",
    "axes[1].set_axis_off()\n",
    "\n",
    "# for ax in axes.flat[2:]:\n",
    "for i, (model_name, y_pred) in enumerate(plot_images.items()):\n",
    "    print(model_name)\n",
    "#     plt.subplot(len(precip_models)+2, 1, i+3)\n",
    "    axes[i+2].imshow(y_pred*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\n",
    "    axes[i+2].set_title(model_name, fontsize=15)\n",
    "    axes[i+2].set_axis_off()\n",
    "#     im = ax.imshow(np.random.random((16, 16)), cmap='viridis',\n",
    "#                    vmin=0, vmax=1)\n",
    "\n",
    "output_dir = \"plots\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "fig.subplots_adjust(wspace=0.02)\n",
    "cbar = fig.colorbar(im, ax=axesgrid.ravel().tolist(), shrink=1)\n",
    "cbar.ax.set_ylabel('mm/5min', fontsize=15)\n",
    "plt.savefig(os.path.join(output_dir, \"Precip_example.eps\"))\n",
    "plt.savefig(os.path.join(output_dir, \"Precip_example.png\"), dpi=300)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}


================================================
FILE: pyproject.toml
================================================
[project]
name = "smaat-unet"
version = "0.1.0"
description = "Code for the paper `SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture`"
authors = [{ name = "Kevin Trebing", email = "Kevin.Trebing@gmx.net" }]
readme = "README.md"
requires-python = ">=3.10.11"
dependencies = [
    "h5py>=3.13.0",
    "lightning>=2.5.0.post0",
    "matplotlib>=3.10.0",
    "numpy>=2.2.3",
    "tensorboard>=2.19.0",
    "torch>=2.6.0",
    "torchsummary>=1.5.1",
    "tqdm>=4.67.1",
]

[dependency-groups]
dev = [
    "mypy>=1.15.0",
    "pre-commit>=4.1.0",
    "ruff>=0.9.7",
]
scripts = [
    "ipykernel>=6.29.5",
]

[[tool.uv.index]]
name = "pytorch-cu126"
url = "https://download.pytorch.org/whl/cu126"
explicit = true

[tool.uv.sources]
torch = [
  { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
torchvision = [
  { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]

[tool.mypy]
python_version = "3.10"
ignore_missing_imports = true

[tool.ruff]
line-length = 120

[tool.ruff.lint]
select = [
    "E",   # pycodestyle Errors
    "W",   # pycodestyle Warnings
    "F",   # pyflakes
    "UP",  # pyupgrade
#    "D",   # pydocstyle
    "B",   # flake8-bugbear
    "C4",  # flake8-comprehensions
    "A",   # flake8-builtins
    "C90", # mccabe complexity
    "DJ",  # flake8-django
    "PIE", # flake8-pie
#    "SIM", # flake8-simplify
]
ignore = [
    "B905", # length-checking in zip() only introduced with Python 3.10 (PEP618)
    "UP007",  # Python version 3.9 does not allow writing union types as X | Y
    "UP038",  # keep isinstance/issubclass tuples instead of X | Y unions
    "D202",  # No blank lines allowed after function docstring
    "D100",  # Missing docstring in public module
    "D104",  # Missing docstring in public package
    "D205",  # 1 blank line required between summary line and description
    "D203",  # One blank line before class docstring
    "D212",  # Multi-line docstring summary should start at the first line
    "D213",  # Multi-line docstring summary should start at the second line
    "D400",  # First line of docstring should end with a period
    "D404",  # First word of the docstring should not be "This"
    "D415",  # First line should end with a period, question mark, or exclamation point
    "DJ001", # Avoid using `null=True` on string-based fields
]
exclude = [
    ".git",
    ".local",
    ".cache",
    ".venv",
    "./venv",
    ".vscode",
    "__pycache__",
    "docs",
    "build",
    "dist",
    "notebooks",
    "migrations"
]

[tool.ruff.lint.mccabe]
# Flag errors (`C901`) whenever the complexity level exceeds 12.
max-complexity = 12


================================================
FILE: requirements-dev.txt
================================================
# This file was autogenerated by uv via the following command:
#    uv export --format=requirements-txt --group dev --no-hashes --no-sources --output-file=requirements-dev.txt
absl-py==2.1.0
aiohappyeyeballs==2.4.6
aiohttp==3.11.12
aiosignal==1.3.2
async-timeout==5.0.1 ; python_full_version < '3.11'
attrs==25.1.0
cfgv==3.4.0
colorama==0.4.6 ; sys_platform == 'win32'
contourpy==1.3.1
cycler==0.12.1
distlib==0.3.9
filelock==3.17.0
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2025.2.0
grpcio==1.70.0
h5py==3.13.0
identify==2.6.7
idna==3.10
jinja2==3.1.5
kiwisolver==1.4.8
lightning==2.5.0.post0
lightning-utilities==0.12.0
markdown==3.7
markupsafe==3.0.2
matplotlib==3.10.0
mpmath==1.3.0
multidict==6.1.0
mypy==1.15.0
mypy-extensions==1.0.0
networkx==3.4.2
nodeenv==1.9.1
numpy==2.2.3
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
packaging==24.2
pillow==11.1.0
platformdirs==4.3.6
pre-commit==4.1.0
propcache==0.3.0
protobuf==5.29.3
pyparsing==3.2.1
python-dateutil==2.9.0.post0
pytorch-lightning==2.5.0.post0
pyyaml==6.0.2
ruff==0.9.7
setuptools==75.8.0
six==1.17.0
sympy==1.13.1
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tomli==2.2.1 ; python_full_version < '3.11'
torch==2.6.0
torchmetrics==1.6.1
torchsummary==1.5.1
tqdm==4.67.1
triton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
typing-extensions==4.12.2
virtualenv==20.29.2
werkzeug==3.1.3
yarl==1.18.3


================================================
FILE: requirements-full.txt
================================================
# This file was autogenerated by uv via the following command:
#    uv export --format=requirements-txt --all-groups --no-hashes --no-sources --output-file=requirements-full.txt
absl-py==2.1.0
aiohappyeyeballs==2.4.6
aiohttp==3.11.12
aiosignal==1.3.2
appnope==0.1.4 ; sys_platform == 'darwin'
asttokens==3.0.0
async-timeout==5.0.1 ; python_full_version < '3.11'
attrs==25.1.0
cffi==1.17.1 ; implementation_name == 'pypy'
cfgv==3.4.0
colorama==0.4.6 ; sys_platform == 'win32'
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.12
decorator==5.1.1
distlib==0.3.9
exceptiongroup==1.2.2 ; python_full_version < '3.11'
executing==2.2.0
filelock==3.17.0
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2025.2.0
grpcio==1.70.0
h5py==3.13.0
identify==2.6.7
idna==3.10
ipykernel==6.29.5
ipython==8.32.0
jedi==0.19.2
jinja2==3.1.5
jupyter-client==8.6.3
jupyter-core==5.7.2
kiwisolver==1.4.8
lightning==2.5.0.post0
lightning-utilities==0.12.0
markdown==3.7
markupsafe==3.0.2
matplotlib==3.10.0
matplotlib-inline==0.1.7
mpmath==1.3.0
multidict==6.1.0
mypy==1.15.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
numpy==2.2.3
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
packaging==24.2
parso==0.8.4
pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
pillow==11.1.0
platformdirs==4.3.6
pre-commit==4.1.0
prompt-toolkit==3.0.50
propcache==0.3.0
protobuf==5.29.3
psutil==7.0.0
ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
pure-eval==0.2.3
pycparser==2.22 ; implementation_name == 'pypy'
pygments==2.19.1
pyparsing==3.2.1
python-dateutil==2.9.0.post0
pytorch-lightning==2.5.0.post0
pywin32==308 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32'
pyyaml==6.0.2
pyzmq==26.2.1
ruff==0.9.7
setuptools==75.8.0
six==1.17.0
stack-data==0.6.3
sympy==1.13.1
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tomli==2.2.1 ; python_full_version < '3.11'
torch==2.6.0
torchmetrics==1.6.1
torchsummary==1.5.1
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
triton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
typing-extensions==4.12.2
virtualenv==20.29.2
wcwidth==0.2.13
werkzeug==3.1.3
yarl==1.18.3


================================================
FILE: requirements.txt
================================================
# This file was autogenerated by uv via the following command:
#    uv export --format=requirements-txt --no-dev --no-hashes --no-sources --output-file=requirements.txt
absl-py==2.1.0
aiohappyeyeballs==2.4.6
aiohttp==3.11.12
aiosignal==1.3.2
async-timeout==5.0.1 ; python_full_version < '3.11'
attrs==25.1.0
colorama==0.4.6 ; sys_platform == 'win32'
contourpy==1.3.1
cycler==0.12.1
filelock==3.17.0
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2025.2.0
grpcio==1.70.0
h5py==3.13.0
idna==3.10
jinja2==3.1.5
kiwisolver==1.4.8
lightning==2.5.0.post0
lightning-utilities==0.12.0
markdown==3.7
markupsafe==3.0.2
matplotlib==3.10.0
mpmath==1.3.0
multidict==6.1.0
networkx==3.4.2
numpy==2.2.3
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
packaging==24.2
pillow==11.1.0
propcache==0.3.0
protobuf==5.29.3
pyparsing==3.2.1
python-dateutil==2.9.0.post0
pytorch-lightning==2.5.0.post0
pyyaml==6.0.2
setuptools==75.8.0
six==1.17.0
sympy==1.13.1
tensorboard==2.19.0
tensorboard-data-server==0.7.2
torch==2.6.0
torchmetrics==1.6.1
torchsummary==1.5.1
tqdm==4.67.1
triton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
typing-extensions==4.12.2
werkzeug==3.1.3
yarl==1.18.3


================================================
FILE: root.py
================================================
from pathlib import Path

ROOT_DIR = Path(__file__).parent


================================================
FILE: train_SmaAtUNet.py
================================================
from typing import Optional

from models.SmaAt_UNet import SmaAt_UNet
import torch
from torch.utils.data import DataLoader
from torch import optim
from torch import nn
from torchvision import transforms

from root import ROOT_DIR
from utils import dataset_VOC
import time
from tqdm import tqdm
from metric import iou
import os


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]


def fit(
    epochs,
    model,
    loss_func,
    opt,
    train_dl,
    valid_dl,
    dev=None,
    save_every: Optional[int] = None,
    tensorboard: bool = False,
    earlystopping=None,
    lr_scheduler=None,
):
    writer = None
    if tensorboard:
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter(comment=f"{model.__class__.__name__}")

    if dev is None:
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    start_time = time.time()
    best_mIoU = -1.0
    earlystopping_counter = 0
    for epoch in tqdm(range(epochs), desc="Epochs", leave=True):
        model.train()
        train_loss = 0.0
        for _, (xb, yb) in enumerate(tqdm(train_dl, desc="Batches", leave=False)):
            # for i, (xb, yb) in enumerate(train_dl):
            loss = loss_func(model(xb.to(dev)), yb.to(dev))
            opt.zero_grad()
            loss.backward()
            opt.step()
            train_loss += loss.item()
            # if i > 100:
            #     break
        train_loss /= len(train_dl)

        # Reduce learning rate after epoch
        # scheduler.step()

        # Calc validation loss
        val_loss = 0.0
        iou_metric = iou.IoU(21, normalized=False)
        model.eval()
        with torch.no_grad():
            for xb, yb in tqdm(valid_dl, desc="Validation", leave=False):
                # for xb, yb in valid_dl:
                y_pred = model(xb.to(dev))
                loss = loss_func(y_pred, yb.to(dev))
                val_loss += loss.item()
                # Calculate mean IOU
                pred_class = torch.argmax(nn.functional.softmax(y_pred, dim=1), dim=1)
                iou_metric.add(pred_class, target=yb)

            iou_class, mean_iou = iou_metric.value()
            val_loss /= len(valid_dl)

        # Save the model with the best mean IoU
        if mean_iou > best_mIoU:
            os.makedirs("checkpoints", exist_ok=True)
            torch.save(
                {
                    "model": model,
                    "epoch": epoch,
                    "state_dict": model.state_dict(),
                    "optimizer_state_dict": opt.state_dict(),
                    "val_loss": val_loss,
                    "train_loss": train_loss,
                    "mIOU": mean_iou,
                },
                ROOT_DIR / "checkpoints" / f"best_mIoU_model_{model.__class__.__name__}.pt",
            )
            best_mIoU = mean_iou
            earlystopping_counter = 0

        else:
            earlystopping_counter += 1
            if earlystopping is not None and earlystopping_counter >= earlystopping:
                print(f"Stopping early --> mean IoU has not increased over {earlystopping} epochs")
                break

        print(
            f"Epoch: {epoch:5d}, Time: {(time.time() - start_time) / 60:.3f} min,"
            f"Train_loss: {train_loss:2.10f}, Val_loss: {val_loss:2.10f},",
            f"mIOU: {mean_iou:.10f},",
            f"lr: {get_lr(opt)},",
            f"Early stopping counter: {earlystopping_counter}/{earlystopping}" if earlystopping is not None else "",
        )

        if writer:
            # add to tensorboard
            writer.add_scalar("Loss/train", train_loss, epoch)
            writer.add_scalar("Loss/val", val_loss, epoch)
            writer.add_scalar("Metric/mIOU", mean_iou, epoch)
            writer.add_scalar("Parameters/learning_rate", get_lr(opt), epoch)
        if save_every is not None and epoch % save_every == 0:
            # save model
            torch.save(
                {
                    "model": model,
                    "epoch": epoch,
                    "state_dict": model.state_dict(),
                    "optimizer_state_dict": opt.state_dict(),
                    # 'scheduler_state_dict': scheduler.state_dict(),
                    "val_loss": val_loss,
                    "train_loss": train_loss,
                    "mIOU": mean_iou,
                },
                ROOT_DIR / "checkpoints" / f"model_{model.__class__.__name__}_epoch_{epoch}.pt",
            )
        if lr_scheduler is not None:
            lr_scheduler.step(mean_iou)


if __name__ == "__main__":
    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dataset_folder = ROOT_DIR / "data" / "VOCdevkit"
    batch_size = 8
    learning_rate = 0.001
    epochs = 200
    earlystopping = 30
    save_every = 1

    # Load your dataset here
    transformations = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])
    voc_dataset_train = dataset_VOC.VOCSegmentation(
        root=dataset_folder,
        image_set="train",
        transformations=transformations,
        augmentations=True,
    )
    voc_dataset_val = dataset_VOC.VOCSegmentation(
        root=dataset_folder,
        image_set="val",
        transformations=transformations,
        augmentations=False,
    )
    train_dl = DataLoader(
        dataset=voc_dataset_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )
    valid_dl = DataLoader(
        dataset=voc_dataset_val,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    # Load SmaAt-UNet
    model = SmaAt_UNet(n_channels=3, n_classes=21)
    # Move model to device
    model.to(dev)
    # Define Optimizer and loss
    opt = optim.Adam(model.parameters(), lr=learning_rate)
    loss_func = nn.CrossEntropyLoss().to(dev)

    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.1, patience=4)
    # Train network
    fit(
        epochs=epochs,
        model=model,
        loss_func=loss_func,
        opt=opt,
        train_dl=train_dl,
        valid_dl=valid_dl,
        dev=dev,
        save_every=save_every,
        tensorboard=True,
        earlystopping=earlystopping,
        lr_scheduler=lr_scheduler,
    )
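
The checkpoint/early-stopping bookkeeping inside `fit` can be summarised in isolation: save whenever the monitored mean IoU improves (higher is better), otherwise increment a patience counter and stop when it reaches the limit. A stdlib sketch under those assumptions (`run_early_stopping` is a hypothetical helper, not part of the repository):

```python
def run_early_stopping(metric_history, patience):
    """Mimic fit()'s loop: track the best mIoU and a patience counter.

    Returns (best_metric, epoch_of_best, stopped_early).
    """
    best, best_epoch, counter = -1.0, -1, 0
    for epoch, metric in enumerate(metric_history):
        if metric > best:
            # In fit() this is where the checkpoint is written.
            best, best_epoch, counter = metric, epoch, 0
        else:
            counter += 1
            if patience is not None and counter >= patience:
                return best, best_epoch, True  # no improvement for `patience` epochs
    return best, best_epoch, False

# Improvement for three epochs, then a plateau long enough to trigger stopping.
print(run_early_stopping([0.1, 0.2, 0.3, 0.25, 0.28, 0.29], patience=3))
# (0.3, 2, True)
```

Note that, as in `fit`, `patience=None` disables early stopping entirely while still tracking the best checkpoint.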


================================================
FILE: train_precip_lightning.py
================================================
from root import ROOT_DIR

import lightning.pytorch as pl
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    LearningRateMonitor,
    EarlyStopping,
)
from lightning.pytorch import loggers
import argparse
from models import unet_precip_regression_lightning as unet_regr
from lightning.pytorch.tuner import Tuner


def train_regression(hparams, find_batch_size_automatically: bool = False):
    if hparams.model == "UNetDSAttention":
        net = unet_regr.UNetDSAttention(hparams=hparams)
    elif hparams.model == "UNetAttention":
        net = unet_regr.UNetAttention(hparams=hparams)
    elif hparams.model == "UNet":
        net = unet_regr.UNet(hparams=hparams)
    elif hparams.model == "UNetDS":
        net = unet_regr.UNetDS(hparams=hparams)
    else:
        raise NotImplementedError(f"Model '{hparams.model}' not implemented")

    default_save_path = ROOT_DIR / "lightning" / "precip_regression"

    checkpoint_callback = ModelCheckpoint(
        dirpath=default_save_path / net.__class__.__name__,
        filename=net.__class__.__name__ + "_rain_threshold_50_{epoch}-{val_loss:.6f}",
        save_top_k=1,
        verbose=False,
        monitor="val_loss",
        mode="min",
    )

    last_checkpoint_callback = ModelCheckpoint(
        dirpath=default_save_path / net.__class__.__name__,
        filename=net.__class__.__name__ + "_rain_threshold_50_{epoch}-{val_loss:.6f}_last",
        save_top_k=1,
        verbose=False,
    )
    
    lr_monitor = LearningRateMonitor()
    tb_logger = loggers.TensorBoardLogger(save_dir=default_save_path, name=net.__class__.__name__)

    earlystopping_callback = EarlyStopping(
        monitor="val_loss",
        mode="min",
        patience=hparams.es_patience,
    )
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        fast_dev_run=hparams.fast_dev_run,
        max_epochs=hparams.epochs,
        default_root_dir=default_save_path,
        logger=tb_logger,
        callbacks=[checkpoint_callback, last_checkpoint_callback, earlystopping_callback, lr_monitor],
        val_check_interval=hparams.val_check_interval,
    )

    if find_batch_size_automatically:
        tuner = Tuner(trainer)

        # Auto-scale batch size using binary search (mode="binsearch")
        tuner.scale_batch_size(net, mode="binsearch")

    # This can be used to speed up training with newer GPUs:
    # https://lightning.ai/docs/pytorch/stable/advanced/speed.html#low-precision-matrix-multiplication
    # torch.set_float32_matmul_precision('medium')

    trainer.fit(model=net, ckpt_path=hparams.resume_from_checkpoint)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser = unet_regr.PrecipRegressionBase.add_model_specific_args(parser)

    parser.add_argument(
        "--dataset_folder",
        default=ROOT_DIR / "data" / "precipitation" / "RAD_NL25_RAC_5min_train_test_2016-2019.h5",
        type=str,
    )
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--fast_dev_run", action="store_true")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--val_check_interval", type=float, default=None)

    args = parser.parse_args()

    # args.fast_dev_run = True
    args.n_channels = 12
    # args.gpus = 1
    args.model = "UNetDSAttention"
    args.lr_patience = 4
    args.es_patience = 15
    # args.val_check_interval = 0.25
    args.kernels_per_layer = 2
    args.use_oversampled_dataset = True
    args.dataset_folder = (
        ROOT_DIR / "data" / "precipitation" / "train_test_2016-2019_input-length_12_img-ahead_6_rain-threshold_50.h5"
    )
    # args.resume_from_checkpoint = f"lightning/precip_regression/{args.model}/UNetDSAttention.ckpt"

    # train_regression(args, find_batch_size_automatically=False)

    # All the models below will be trained
    for m in ["UNet", "UNetDS", "UNetAttention", "UNetDSAttention"]:
        args.model = m
        print(f"Start training model: {m}")
        train_regression(args, find_batch_size_automatically=False)

================================================
FILE: utils/__init__.py
================================================


================================================
FILE: utils/data_loader_precip.py
================================================
import torch
import numpy as np
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from utils import dataset_precip
from root import ROOT_DIR


# Taken from: https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb
def get_train_valid_loader(
    data_dir,
    batch_size,
    random_seed,
    num_input_images,
    num_output_images,
    augment,
    classification,
    valid_size=0.1,
    shuffle=True,
    num_workers=4,
    pin_memory=False,
):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the CIFAR-10 dataset. A sample
    9x9 grid of the images can be optionally displayed.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - augment: whether to apply the data augmentation scheme
      mentioned in the paper. Only applied on the train split.
    - random_seed: fix seed for reproducibility.
    - valid_size: percentage split of the training set used for
      the validation set. Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the train/validation indices.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert (valid_size >= 0) and (valid_size <= 1), error_msg

    # Since I am not dealing with RGB images I do not need this
    # normalize = transforms.Normalize(
    #     mean=[0.4914, 0.4822, 0.4465],
    #     std=[0.2023, 0.1994, 0.2010],
    # )

    # define transforms
    valid_transform = None
    # valid_transform = transforms.Compose([
    #         transforms.ToTensor(),
    #         normalize,
    # ])
    if augment:
        # TODO flipping, rotating, sequence flipping (torch.flip seems very expensive)
        train_transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                # transforms.ToTensor(),
                # normalize,
            ]
        )
    else:
        train_transform = None
        # train_transform = transforms.Compose([
        #     transforms.ToTensor(),
        #     normalize,
        # ])
    if classification:
        # load the dataset
        train_dataset = dataset_precip.precipitation_maps_classification_h5(
            in_file=data_dir,
            num_input_images=num_input_images,
            img_to_predict=num_output_images,
            train=True,
            transform=train_transform,
        )

        valid_dataset = dataset_precip.precipitation_maps_classification_h5(
            in_file=data_dir,
            num_input_images=num_input_images,
            img_to_predict=num_output_images,
            train=True,
            transform=valid_transform,
        )
    else:
        # load the dataset
        train_dataset = dataset_precip.precipitation_maps_h5(
            in_file=data_dir,
            num_input_images=num_input_images,
            num_output_images=num_output_images,
            train=True,
            transform=train_transform,
        )

        valid_dataset = dataset_precip.precipitation_maps_h5(
            in_file=data_dir,
            num_input_images=num_input_images,
            num_output_images=num_output_images,
            train=True,
            transform=valid_transform,
        )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        sampler=valid_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, valid_loader
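The index bookkeeping above (shuffle, then carve off the first `valid_size` fraction for validation) can be sketched in isolation; the helper name and sizes below are illustrative, not part of the repo:

```python
import numpy as np


def split_indices(num_samples, valid_size=0.1, seed=1337, shuffle=True):
    # Mirror of the split above: optionally shuffle, then the first
    # `valid_size` fraction becomes the validation indices.
    indices = list(range(num_samples))
    split = int(np.floor(valid_size * num_samples))
    if shuffle:
        rng = np.random.RandomState(seed)
        rng.shuffle(indices)
    return indices[split:], indices[:split]


train_idx, valid_idx = split_indices(100, valid_size=0.1)
```

Feeding these index lists into `SubsetRandomSampler` (as above) keeps the two loaders disjoint while sharing one underlying dataset object.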


def get_test_loader(
    data_dir,
    batch_size,
    num_input_images,
    num_output_images,
    classification,
    shuffle=False,
    num_workers=4,
    pin_memory=False,
):
    """
    Utility function for loading and returning a multi-process
    test iterator over the CIFAR-10 dataset.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - shuffle: whether to shuffle the dataset after every epoch.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    Returns
    -------
    - data_loader: test set iterator.
    """
    # Since I am not dealing with RGB images I do not need this
    # normalize = transforms.Normalize(
    #     mean=[0.485, 0.456, 0.406],
    #     std=[0.229, 0.224, 0.225],
    # )

    # define transform
    transform = None
    # transform = transforms.Compose([
    #     transforms.ToTensor(),
    #     # normalize,
    # ])
    if classification:
        dataset = dataset_precip.precipitation_maps_classification_h5(
            in_file=data_dir,
            num_input_images=num_input_images,
            img_to_predict=num_output_images,
            train=False,
            transform=transform,
        )
    else:
        dataset = dataset_precip.precipitation_maps_h5(
            in_file=data_dir,
            num_input_images=num_input_images,
            num_output_images=num_output_images,
            train=False,
            transform=transform,
        )

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return data_loader


if __name__ == "__main__":
    folder = ROOT_DIR / "data" / "precipitation"
    data = "RAD_NL25_RAC_5min_train_test_2016-2019.h5"
    train_dl, valid_dl = get_train_valid_loader(
        folder / data,
        batch_size=8,
        random_seed=1337,
        num_input_images=12,
        num_output_images=6,
        classification=True,
        augment=False,
        valid_size=0.1,
        shuffle=True,
        num_workers=4,
        pin_memory=False,
    )
    for xb, yb in train_dl:
        print("xb.shape: ", xb.shape)
        print("yb.shape: ", yb.shape)
        break


================================================
FILE: utils/dataset_VOC.py
================================================
import random
import torch
from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
from torchvision import transforms
import torchvision.transforms.functional as TF
import numpy as np
import matplotlib.pyplot as plt


def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors
    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray(
        [
            [0, 0, 0],
            [128, 0, 0],
            [0, 128, 0],
            [128, 128, 0],
            [0, 0, 128],
            [128, 0, 128],
            [0, 128, 128],
            [128, 128, 128],
            [64, 0, 0],
            [192, 0, 0],
            [64, 128, 0],
            [192, 128, 0],
            [64, 0, 128],
            [192, 0, 128],
            [64, 128, 128],
            [192, 128, 128],
            [0, 64, 0],
            [128, 64, 0],
            [0, 192, 0],
            [128, 192, 0],
            [0, 64, 128],
        ]
    )


def decode_segmap(label_mask, plot=False):
    """Decode segmentation class labels into a color image
    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
          the class label at each spatial location.
        plot (bool, optional): whether to show the resulting color image
          in a figure.
    Returns:
        (np.ndarray, optional): the resulting decoded color image.
    """
    label_colours = get_pascal_labels()
    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(21):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    return rgb


class VOCSegmentation(Dataset):
    CLASS_NAMES = [
        "background",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "potted-plant",
        "sheep",
        "sofa",
        "train",
        "tv/monitor",
    ]

    def __init__(self, root: Path, image_set="train", transformations=None, augmentations=False):
        super().__init__()
        assert image_set in ["train", "val", "trainval"]

        voc_root = root / "VOC2012"
        image_dir = voc_root / "JPEGImages"
        mask_dir = voc_root / "SegmentationClass"

        splits_dir = voc_root / "ImageSets" / "Segmentation"

        split_f = splits_dir / (image_set.rstrip("\n") + ".txt")
        with open(split_f) as f:
            file_names = [x.strip() for x in f.readlines()]

        self.images = [image_dir / (x + ".jpg") for x in file_names]
        self.masks = [mask_dir / (x + ".png") for x in file_names]
        assert len(self.images) == len(self.masks)

        self.transformations = transformations
        self.augmentations = augmentations

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.masks[index])

        if self.transformations is not None:
            img = self.transformations(img)
            target = self.transformations(target)
        # Apply augmentations to both input and target
        if self.augmentations:
            img, target = self.apply_augmentations(img, target)

        # Convert the RGB image to a tensor
        toTensorTransform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )
        img = toTensorTransform(img)
        # Convert target to long tensor
        target = torch.from_numpy(np.array(target)).long()
        target[target == 255] = 0

        return img, target

    def __len__(self):
        return len(self.images)

    def apply_augmentations(self, img, target):
        # Horizontal flip
        if random.random() > 0.5:
            img = TF.hflip(img)
            target = TF.hflip(target)
        # Random Rotation (clockwise and counterclockwise)
        if random.random() > 0.5:
            degrees = 10
            if random.random() > 0.5:
                degrees *= -1
            img = TF.rotate(img, degrees)
            target = TF.rotate(target, degrees)
        # Brighten or darken image (only applied to input image)
        if random.random() > 0.5:
            brightness = 1.2
            if random.random() > 0.5:
                brightness -= 0.4
            img = TF.adjust_brightness(img, brightness)
        return img, target
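A key detail in `apply_augmentations` is that geometric transforms (flip, rotation) are applied to both the image and the mask with the same random decision, while the brightness change touches only the image. A numpy-only sketch of the synchronized flip (the helper name is illustrative, not part of the repo):

```python
import random

import numpy as np


def paired_hflip(img, mask, p=0.5, rng=random):
    # One random draw decides for BOTH arrays, keeping pixels and labels aligned
    if rng.random() < p:
        return np.flip(img, axis=-1).copy(), np.flip(mask, axis=-1).copy()
    return img, mask


img = np.arange(6).reshape(2, 3)
mask = np.arange(6).reshape(2, 3)
f_img, f_mask = paired_hflip(img, mask, p=1.0)  # force the flip
```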


================================================
FILE: utils/dataset_precip.py
================================================
from torch.utils.data import Dataset
import h5py
import numpy as np


class precipitation_maps_h5(Dataset):
    def __init__(self, in_file, num_input_images, num_output_images, train=True, transform=None):
        super().__init__()

        self.file_name = in_file
        self.n_images, self.nx, self.ny = h5py.File(self.file_name, "r")["train" if train else "test"]["images"].shape

        self.num_input = num_input_images
        self.num_output = num_output_images
        self.sequence_length = num_input_images + num_output_images

        self.train = train
        # Dataset is all the images
        self.size_dataset = self.n_images - (num_input_images + num_output_images)
        # self.size_dataset = int(self.n_images/(num_input_images+num_output_images))
        self.transform = transform
        self.dataset = None

    def __getitem__(self, index):
        # min_feature_range = 0.0
        # max_feature_range = 1.0
        # with h5py.File(self.file_name, 'r') as dataFile:
        #     dataset = dataFile["train" if self.train else "test"]['images'][index:index+self.sequence_length]
        # load the file here (load as singleton)
        if self.dataset is None:
            self.dataset = h5py.File(self.file_name, "r", rdcc_nbytes=1024**3)["train" if self.train else "test"][
                "images"
            ]
        imgs = np.array(self.dataset[index : index + self.sequence_length], dtype="float32")

        # add transforms
        if self.transform is not None:
            imgs = self.transform(imgs)
        input_img = imgs[: self.num_input]
        target_img = imgs[-1]

        return input_img, target_img

    def __len__(self):
        return self.size_dataset


class precipitation_maps_oversampled_h5(Dataset):
    def __init__(self, in_file, num_input_images, num_output_images, train=True, transform=None):
        super().__init__()

        self.file_name = in_file
        self.samples, _, _, _ = h5py.File(self.file_name, "r")["train" if train else "test"]["images"].shape

        self.num_input = num_input_images
        self.num_output = num_output_images

        self.train = train
        # self.size_dataset = int(self.n_images/(num_input_images+num_output_images))
        self.transform = transform
        self.dataset = None

    def __getitem__(self, index):
        # load the file here (load as singleton)
        if self.dataset is None:
            self.dataset = h5py.File(self.file_name, "r", rdcc_nbytes=1024**3)["train" if self.train else "test"][
                "images"
            ]
        imgs = np.array(self.dataset[index], dtype="float32")

        # add transforms
        if self.transform is not None:
            imgs = self.transform(imgs)
        input_img = imgs[: self.num_input]
        target_img = imgs[-1]

        return input_img, target_img

    def __len__(self):
        return self.samples


class precipitation_maps_classification_h5(Dataset):
    def __init__(self, in_file, num_input_images, img_to_predict, train=True, transform=None):
        super().__init__()

        self.file_name = in_file
        self.n_images, self.nx, self.ny = h5py.File(self.file_name, "r")["train" if train else "test"]["images"].shape

        self.num_input = num_input_images
        self.img_to_predict = img_to_predict
        self.sequence_length = num_input_images + img_to_predict
        self.bins = np.array([0.0, 0.5, 1, 2, 5, 10, 30])

        self.train = train
        # Dataset is all the images
        self.size_dataset = self.n_images - (num_input_images + img_to_predict)
        # self.size_dataset = int(self.n_images/(num_input_images+num_output_images))
        self.transform = transform
        self.dataset = None

    def __getitem__(self, index):
        # min_feature_range = 0.0
        # max_feature_range = 1.0
        # with h5py.File(self.file_name, 'r') as dataFile:
        #     dataset = dataFile["train" if self.train else "test"]['images'][index:index+self.sequence_length]
        # load the file here (load as singleton)
        if self.dataset is None:
            self.dataset = h5py.File(self.file_name, "r", rdcc_nbytes=1024**3)["train" if self.train else "test"][
                "images"
            ]
        imgs = np.array(self.dataset[index : index + self.sequence_length], dtype="float32")

        # add transforms
        if self.transform is not None:
            imgs = self.transform(imgs)
        input_img = imgs[: self.num_input]
        # put the img in buckets
        target_img = imgs[-1]
        # target_img is normalized by dividing through the highest value of the training set. We reverse this.
        # Then target_img is in mm/5min. The bins have the unit mm/hour. Therefore we multiply the img by 12
        buckets = np.digitize(target_img * 47.83 * 12, self.bins, right=True)

        return input_img, buckets

    def __len__(self):
        return self.size_dataset
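The bucketing performed by `np.digitize` above can be checked in isolation. The normalization constant 47.83 (the training-set maximum, in mm/5min) and the mm/hour bins are taken from the class above; the sample pixel values are made up:

```python
import numpy as np

bins = np.array([0.0, 0.5, 1, 2, 5, 10, 30])  # mm/hour thresholds

# Hypothetical normalized pixel values (fractions of the training-set maximum)
normalized = np.array([0.0, 0.001, 0.01, 0.1])

# Undo normalization (back to mm/5min), then scale by 12 to get mm/hour
rain_mm_per_hour = normalized * 47.83 * 12

# right=True: a value equal to a bin edge falls into the lower bucket
buckets = np.digitize(rain_mm_per_hour, bins, right=True)
```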


================================================
FILE: utils/formatting.py
================================================
import torch
import numpy as np


def make_metrics_str(metrics_dict):
    # Format metrics as "name: value | ..." while skipping NaN entries
    parts = []
    for name, value in metrics_dict.items():
        value = value.item() if isinstance(value, torch.Tensor) else value
        if not np.isnan(value):
            parts.append(f"{name}: {value:.4f}")
    return " | ".join(parts)
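A quick usage sketch (the metric names and values are made up; a local stand-in with the same behavior is defined so the snippet is self-contained):

```python
import numpy as np
import torch


def format_metrics(metrics_dict):
    # Same behavior as make_metrics_str: tensors are unwrapped, NaNs skipped
    parts = []
    for name, value in metrics_dict.items():
        value = value.item() if isinstance(value, torch.Tensor) else value
        if not np.isnan(value):
            parts.append(f"{name}: {value:.4f}")
    return " | ".join(parts)


metrics = {"mse": torch.tensor(0.0123), "f1": 0.8571, "csi": float("nan")}
result = format_metrics(metrics)  # the NaN "csi" entry is dropped
```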


================================================
FILE: utils/model_classes.py
================================================
from models import unet_precip_regression_lightning as unet_regr
import lightning.pytorch as pl


def get_model_class(model_file) -> tuple[type[pl.LightningModule], str]:
    # This is for some nice plotting
    if "UNetAttention" in model_file:
        model_name = "UNet Attention"
        model = unet_regr.UNetAttention
    elif "UNetDSAttention4kpl" in model_file:
        model_name = "UNetDS Attention with 4kpl"
        model = unet_regr.UNetDSAttention
    elif "UNetDSAttention1kpl" in model_file:
        model_name = "UNetDS Attention with 1kpl"
        model = unet_regr.UNetDSAttention
    elif "UNetDSAttention4CBAMs" in model_file:
        model_name = "UNetDS Attention 4CBAMs"
        model = unet_regr.UNetDSAttention4CBAMs
    elif "UNetDSAttention" in model_file:
        model_name = "SmaAt-UNet"
        model = unet_regr.UNetDSAttention
    elif "UNetDS" in model_file:
        model_name = "UNetDS"
        model = unet_regr.UNetDS
    elif "UNet" in model_file:
        model_name = "UNet"
        model = unet_regr.UNet
    elif "PersistenceModel" in model_file:
        model_name = "PersistenceModel"
        model = unet_regr.PersistenceModel
    else:
        raise NotImplementedError("Model not found")
    return model, model_name
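Because matching is by substring, branch order in `get_model_class` matters: more specific names must be tested before their prefixes (e.g. "UNetDSAttention4CBAMs" before "UNetDSAttention", which in turn precedes "UNetDS" and "UNet"). A minimal first-match dispatch illustrating this (rule list abridged from the function above; the helper name is illustrative):

```python
def pick_label(model_file):
    # First match wins, so patterns go from most to least specific
    rules = [
        ("UNetDSAttention4CBAMs", "UNetDS Attention 4CBAMs"),
        ("UNetDSAttention", "SmaAt-UNet"),
        ("UNetDS", "UNetDS"),
        ("UNet", "UNet"),
    ]
    for pattern, label in rules:
        if pattern in model_file:
            return label
    raise NotImplementedError("Model not found")
```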
SYMBOL INDEX (140 symbols across 19 files)

FILE: calc_metrics_test_set.py
  function parse_args (line 19) | def parse_args():
  function convert_tensors_to_python (line 38) | def convert_tensors_to_python(obj):
  function save_metrics (line 51) | def save_metrics(results, file_path, file_type="json"):
  function run_experiments (line 75) | def run_experiments(model_folder, data_file, threshold=0.5, denormalize=...
  function plot_metrics (line 124) | def plot_metrics(results, save_path=""):
  function load_metrics_from_file (line 164) | def load_metrics_from_file(file_path, file_type):
  function main (line 202) | def main():

FILE: create_datasets.py
  function create_dataset (line 8) | def create_dataset(input_length: int, image_ahead: int, rain_amount_thre...

FILE: metric/confusionmatrix.py
  class ConfusionMatrix (line 6) | class ConfusionMatrix(metric.Metric):
    method __init__ (line 16) | def __init__(self, num_classes, normalized=False):
    method reset (line 24) | def reset(self):
    method add (line 27) | def add(self, predicted, target):
    method value (line 73) | def value(self):

FILE: metric/iou.py
  class IoU (line 7) | class IoU(metric.Metric):
    method __init__ (line 22) | def __init__(self, num_classes, normalized=False, ignore_index=None):
    method reset (line 36) | def reset(self):
    method add (line 39) | def add(self, predicted, target):
    method value (line 64) | def value(self):

FILE: metric/metric.py
  class Metric (line 4) | class Metric(Protocol):
    method reset (line 7) | def reset(self):
    method add (line 10) | def add(self, predicted, target):
    method value (line 13) | def value(self):

FILE: metric/precipitation_metrics.py
  class PrecipitationMetrics (line 6) | class PrecipitationMetrics(Metric):
    method __init__ (line 12) | def __init__(self, threshold=0.5, denormalize=True, dist_sync_on_step=...
    method update (line 37) | def update(self, preds, target):
    method compute (line 97) | def compute(self):

FILE: models/SmaAt_UNet.py
  class SmaAt_UNet (line 7) | class SmaAt_UNet(nn.Module):
    method __init__ (line 8) | def __init__(
    method forward (line 41) | def forward(self, x):

FILE: models/layers.py
  class DepthToSpace (line 7) | class DepthToSpace(nn.Module):
    method __init__ (line 8) | def __init__(self, block_size):
    method forward (line 12) | def forward(self, x):
  class SpaceToDepth (line 20) | class SpaceToDepth(nn.Module):
    method __init__ (line 22) | def __init__(self, block_size):
    method forward (line 26) | def forward(self, x):
  class DepthwiseSeparableConv (line 34) | class DepthwiseSeparableConv(nn.Module):
    method __init__ (line 35) | def __init__(self, in_channels, output_channels, kernel_size, padding=...
    method forward (line 47) | def forward(self, x):
  class DoubleDense (line 53) | class DoubleDense(nn.Module):
    method __init__ (line 54) | def __init__(self, in_channels, hidden_neurons, output_channels):
    method forward (line 60) | def forward(self, x):
  class DoubleDSConv (line 67) | class DoubleDSConv(nn.Module):
    method __init__ (line 70) | def __init__(self, in_channels, out_channels):
    method forward (line 81) | def forward(self, x):
  class Flatten (line 85) | class Flatten(nn.Module):
    method forward (line 86) | def forward(self, x):
  class ChannelAttention (line 90) | class ChannelAttention(nn.Module):
    method __init__ (line 91) | def __init__(self, input_channels, reduction_ratio=16):
    method forward (line 105) | def forward(self, x):
  class SpatialAttention (line 114) | class SpatialAttention(nn.Module):
    method __init__ (line 115) | def __init__(self, kernel_size=7):
    method forward (line 122) | def forward(self, x):
  class CBAM (line 132) | class CBAM(nn.Module):
    method __init__ (line 133) | def __init__(self, input_channels, reduction_ratio=16, kernel_size=7):
    method forward (line 138) | def forward(self, x):

FILE: models/regression_lightning.py
  class UNetBase (line 12) | class UNetBase(pl.LightningModule):
    method add_model_specific_args (line 14) | def add_model_specific_args(parent_parser):
    method __init__ (line 31) | def __init__(self, hparams):
    method forward (line 44) | def forward(self, x):
    method configure_optimizers (line 47) | def configure_optimizers(self):
    method loss_func (line 57) | def loss_func(self, y_pred, y_true):
    method training_step (line 67) | def training_step(self, batch, batch_idx):
    method validation_step (line 79) | def validation_step(self, batch, batch_idx):
    method test_step (line 88) | def test_step(self, batch, batch_idx):
    method on_test_epoch_end (line 96) | def on_test_epoch_end(self):
    method on_train_epoch_end (line 106) | def on_train_epoch_end(self):
    method on_validation_epoch_end (line 113) | def on_validation_epoch_end(self):
  class PrecipRegressionBase (line 121) | class PrecipRegressionBase(UNetBase):
    method add_model_specific_args (line 123) | def add_model_specific_args(parent_parser):
    method __init__ (line 134) | def __init__(self, hparams):
    method prepare_data (line 141) | def prepare_data(self):
    method train_dataloader (line 177) | def train_dataloader(self):
    method val_dataloader (line 189) | def val_dataloader(self):
  class PersistenceModel (line 202) | class PersistenceModel(UNetBase):
    method forward (line 203) | def forward(self, x):

FILE: models/unet_parts.py
  class DoubleConv (line 8) | class DoubleConv(nn.Module):
    method __init__ (line 11) | def __init__(self, in_channels, out_channels, mid_channels=None):
    method forward (line 24) | def forward(self, x):
  class Down (line 28) | class Down(nn.Module):
    method __init__ (line 31) | def __init__(self, in_channels, out_channels):
    method forward (line 35) | def forward(self, x):
  class Up (line 39) | class Up(nn.Module):
    method __init__ (line 42) | def __init__(self, in_channels, out_channels, bilinear=True):
    method forward (line 53) | def forward(self, x1, x2):
  class OutConv (line 67) | class OutConv(nn.Module):
    method __init__ (line 68) | def __init__(self, in_channels, out_channels):
    method forward (line 72) | def forward(self, x):

FILE: models/unet_parts_depthwise_separable.py
  class DoubleConvDS (line 10) | class DoubleConvDS(nn.Module):
    method __init__ (line 13) | def __init__(self, in_channels, out_channels, mid_channels=None, kerne...
    method forward (line 38) | def forward(self, x):
  class DownDS (line 42) | class DownDS(nn.Module):
    method __init__ (line 45) | def __init__(self, in_channels, out_channels, kernels_per_layer=1):
    method forward (line 52) | def forward(self, x):
  class UpDS (line 56) | class UpDS(nn.Module):
    method __init__ (line 59) | def __init__(self, in_channels, out_channels, bilinear=True, kernels_p...
    method forward (line 75) | def forward(self, x1, x2):
  class OutConv (line 89) | class OutConv(nn.Module):
    method __init__ (line 90) | def __init__(self, in_channels, out_channels):
    method forward (line 94) | def forward(self, x):

FILE: models/unet_precip_regression_lightning.py
  class UNet (line 7) | class UNet(PrecipRegressionBase):
    method __init__ (line 8) | def __init__(self, hparams):
    method forward (line 27) | def forward(self, x):
  class UNetAttention (line 41) | class UNetAttention(PrecipRegressionBase):
    method __init__ (line 42) | def __init__(self, hparams):
    method forward (line 67) | def forward(self, x):
  class UNetDS (line 86) | class UNetDS(PrecipRegressionBase):
    method __init__ (line 87) | def __init__(self, hparams):
    method forward (line 107) | def forward(self, x):
  class UNetDSAttention (line 121) | class UNetDSAttention(PrecipRegressionBase):
    method __init__ (line 122) | def __init__(self, hparams):
    method forward (line 148) | def forward(self, x):
  class UNetDSAttention4CBAMs (line 167) | class UNetDSAttention4CBAMs(PrecipRegressionBase):
    method __init__ (line 168) | def __init__(self, hparams):
    method forward (line 193) | def forward(self, x):

FILE: train_SmaAtUNet.py
  function get_lr (line 18) | def get_lr(optimizer):
  function fit (line 23) | def fit(

FILE: train_precip_lightning.py
  function train_regression (line 15) | def train_regression(hparams, find_batch_size_automatically: bool = False):

FILE: utils/data_loader_precip.py
  function get_train_valid_loader (line 10) | def get_train_valid_loader(
  function get_test_loader (line 141) | def get_test_loader(

FILE: utils/dataset_VOC.py
  function get_pascal_labels (line 12) | def get_pascal_labels():
  function decode_segmap (line 44) | def decode_segmap(label_mask, plot=False):
  class VOCSegmentation (line 73) | class VOCSegmentation(Dataset):
    method __init__ (line 98) | def __init__(self, root: Path, image_set="train", transformations=None...
    method __getitem__ (line 119) | def __getitem__(self, index):
    method __len__ (line 147) | def __len__(self):
    method apply_augmentations (line 150) | def apply_augmentations(self, img, target):

FILE: utils/dataset_precip.py
  class precipitation_maps_h5 (line 6) | class precipitation_maps_h5(Dataset):
    method __init__ (line 7) | def __init__(self, in_file, num_input_images, num_output_images, train...
    method __getitem__ (line 24) | def __getitem__(self, index):
    method __len__ (line 44) | def __len__(self):
  class precipitation_maps_oversampled_h5 (line 48) | class precipitation_maps_oversampled_h5(Dataset):
    method __init__ (line 49) | def __init__(self, in_file, num_input_images, num_output_images, train...
    method __getitem__ (line 63) | def __getitem__(self, index):
    method __len__ (line 79) | def __len__(self):
  class precipitation_maps_classification_h5 (line 83) | class precipitation_maps_classification_h5(Dataset):
    method __init__ (line 84) | def __init__(self, in_file, num_input_images, img_to_predict, train=Tr...
    method __getitem__ (line 102) | def __getitem__(self, index):
    method __len__ (line 126) | def __len__(self):

FILE: utils/formatting.py
  function make_metrics_str (line 5) | def make_metrics_str(metrics_dict):

FILE: utils/model_classes.py
  function get_model_class (line 5) | def get_model_class(model_file) -> tuple[type[pl.LightningModule], str]:

About this extraction

This page contains the full source code of the HansBambel/SmaAt-UNet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 32 files (108.5 KB, approximately 29.7k tokens) and a symbol index of 140 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract, a free GitHub-repo-to-text converter for AI, built by Nikandr Surkov.
