Repository: HansBambel/SmaAt-UNet Branch: master Commit: 3e48dcca3e88 Files: 32 Total size: 108.5 KB Directory structure: gitextract_3pkgi8mx/ ├── .gitignore ├── .pre-commit-config.yaml ├── .python-version ├── README.md ├── calc_metrics_test_set.py ├── create_datasets.py ├── metric/ │ ├── __init__.py │ ├── confusionmatrix.py │ ├── iou.py │ ├── metric.py │ └── precipitation_metrics.py ├── models/ │ ├── SmaAt_UNet.py │ ├── __init__.py │ ├── layers.py │ ├── regression_lightning.py │ ├── unet_parts.py │ ├── unet_parts_depthwise_separable.py │ └── unet_precip_regression_lightning.py ├── plot_examples.ipynb ├── pyproject.toml ├── requirements-dev.txt ├── requirements-full.txt ├── requirements.txt ├── root.py ├── train_SmaAtUNet.py ├── train_precip_lightning.py └── utils/ ├── __init__.py ├── data_loader_precip.py ├── dataset_VOC.py ├── dataset_precip.py ├── formatting.py └── model_classes.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .idea .vscode data runs lightning checkpoints plots .ruff_cache .mypy_cache __pycache__ ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/charliermarsh/ruff-pre-commit rev: 'v0.9.7' hooks: - id: ruff args: ["--fix"] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.15.0 hooks: - id: mypy ================================================ FILE: .python-version ================================================ 3.10.11 ================================================ FILE: README.md ================================================ # SmaAt-UNet Code for the Paper "SmaAt-UNet: Precipitation 
Nowcasting using a Small Attention-UNet Architecture" [Arxiv-link](https://arxiv.org/abs/2007.04417), [Elsevier-link](https://www.sciencedirect.com/science/article/pii/S0167865521000556?via%3Dihub) Other relevant publications: [Weather Elements Nowcasting](https://sites.google.com/view/siamak-mehrkanoon/projects/weather-elements-nowcasting). ![SmaAt-UNet](SmaAt-UNet.png) The proposed SmaAt-UNet can be found in the `models` folder under [SmaAt_UNet](models/SmaAt_UNet.py). **>>>IMPORTANT<<<** The original code from the paper can be found in this branch: https://github.com/HansBambel/SmaAt-UNet/tree/snapshot-paper The current master branch has since been refactored and uses upgraded packages. Because the exact package versions differ, the experiments may not be 100% reproducible. If you have problems running the code, feel free to open an issue here on GitHub. ## Installing dependencies This project uses [uv](https://docs.astral.sh/uv/) for dependency management. Therefore, installing the required dependencies is as easy as this: ```shell uv sync --frozen ``` If you have NVIDIA GPU(s) that support CUDA and the command above did not install torch with CUDA support, you can install the relevant dependencies with the following command: ```shell uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 ``` In any case, a [`requirements.txt`](requirements.txt) file is generated from `uv.lock`, which implies that `uv` must be installed in order to perform the export: ```shell uv export --format=requirements-txt --no-dev --no-hashes --no-sources --output-file=requirements.txt ``` A `requirements-dev.txt` file is also generated for development dependencies: ```shell uv export --format=requirements-txt --group dev --no-hashes --no-sources --output-file=requirements-dev.txt ``` A `requirements-full.txt` file is also generated for all dependencies (including development and additional dependencies): ```shell uv export 
--format=requirements-txt --all-groups --no-hashes --no-sources --output-file=requirements-full.txt ``` Alternatively, users can use the existing [`requirements.txt`](requirements.txt), [`requirements-dev.txt`](requirements-dev.txt) and [`requirements-full.txt`](requirements-full.txt) files. Basically, only the following requirements are needed for training: ``` tqdm torch lightning matplotlib tensorboard torchsummary h5py numpy ``` Additional requirements are needed for [testing](#testing): ``` ipykernel ``` --- For the paper we used the [Lightning](https://github.com/Lightning-AI/lightning) module (PL), which simplifies the training process and makes it easy to add loggers and create checkpoints. In order to use PL we created the model [UNetDSAttention](models/unet_precip_regression_lightning.py), whose parent class inherits from `pl.LightningModule`. This model is the same as the pure PyTorch SmaAt-UNet implementation with the added PL functions. ### Training An example [training script](train_SmaAtUNet.py) is given for a classification task (PascalVOC). For training on the precipitation task we used the [train_precip_lightning.py](train_precip_lightning.py) file. The training will place a checkpoint file for every model in the `default_save_path` `lightning/precip_regression`. After training finishes, place the best models (probably the ones with the lowest validation loss) that you want to compare in another folder in `checkpoints/comparison`. ### Testing The script [calc_metrics_test_set.py](calc_metrics_test_set.py) will use all the models placed in `checkpoints/comparison` to calculate the MSEs and other metrics such as Precision, Recall, Accuracy, F1, CSI, FAR, HSS. The results are saved in a JSON-formatted file in the same folder as the models. The metrics of the persistence model (for now) are only calculated using the script [test_precip_lightning.py](test_precip_lightning.py). 
This script will also use all models in that folder and calculate the test losses for the models in addition to the persistence model. This will be addressed in [this issue](https://github.com/HansBambel/SmaAt-UNet/issues/28). ### Plots Example code for creating plots similar to those in the paper can be found in [plot_examples.ipynb](plot_examples.ipynb). You might need to make your kernel discoverable so that Jupyter can detect it. To do so, run the following command in your terminal: ```shell uv run ipython kernel install --user --env VIRTUAL_ENV=$(pwd)/.venv --name=smaat_unet --display-name="SmaAt-UNet Kernel" ``` ### Precipitation dataset The dataset consists of precipitation maps in 5-minute intervals from 2016-2019, resulting in about 420,000 images. The dataset is based on radar precipitation maps from [The Royal Netherlands Meteorological Institute (KNMI)](https://www.knmi.nl/over-het-knmi/about). The original images were cropped as can be seen in the example below: ![Precip cutout](Precipitation%20map%20Cutout.png) If you are interested in the dataset that we used, please write an e-mail to: k.trebing@alumni.maastrichtuniversity.nl and s.mehrkanoon@uu.nl The 50% dataset is 4GB in size and the 20% dataset is 16.5GB in size. Use [create_datasets.py](create_datasets.py) to create the two datasets used from the original dataset. The dataset is already normalized using a [Min-Max normalization](https://en.wikipedia.org/wiki/Feature_scaling#Rescaling_(min-max_normalization)). In order to revert this you need to multiply the images by 47.83; the resulting values are in mm/5min. 
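The reversal of the normalization described above can be sketched as follows. This is a minimal illustration, not code from this repository: the helper names `denormalize` and `to_hourly_rate` are hypothetical, the factor 47.83 is the one stated above, and the conversion to mm/h assumes twelve 5-minute intervals per hour.

```python
import numpy as np

# Normalization factor stated in this README; multiplying by it yields mm/5min.
NORMALIZATION_FACTOR = 47.83


def denormalize(images: np.ndarray) -> np.ndarray:
    """Revert the Min-Max normalization (hypothetical helper, not part of the repo)."""
    return images * NORMALIZATION_FACTOR


def to_hourly_rate(images_mm_per_5min: np.ndarray) -> np.ndarray:
    """Convert mm/5min to mm/h, assuming twelve 5-minute intervals per hour."""
    return images_mm_per_5min * 12


# Illustrative values only, not real radar data.
normalized = np.array([[0.0, 0.5], [1.0, 0.1]], dtype=np.float32)
mm_per_5min = denormalize(normalized)
mm_per_hour = to_hourly_rate(mm_per_5min)
```

A normalized pixel value of 1.0 therefore corresponds to 47.83 mm/5min, and a value of 0.0 stays at 0.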
### Citation ``` @article{TREBING2021, title = {SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture}, journal = {Pattern Recognition Letters}, year = {2021}, issn = {0167-8655}, doi = {https://doi.org/10.1016/j.patrec.2021.01.036}, url = {https://www.sciencedirect.com/science/article/pii/S0167865521000556}, author = {Kevin Trebing and Tomasz Stańczyk and Siamak Mehrkanoon}, keywords = {Domain adaptation, neural networks, kernel methods, coupling regularization}, abstract = {Weather forecasting is dominated by numerical weather prediction that tries to model accurately the physical properties of the atmosphere. A downside of numerical weather prediction is that it is lacking the ability for short-term forecasts using the latest available information. By using a data-driven neural network approach we show that it is possible to produce an accurate precipitation nowcast. To this end, we propose SmaAt-UNet, an efficient convolutional neural networks-based on the well known UNet architecture equipped with attention modules and depthwise-separable convolutions. We evaluate our approaches on a real-life datasets using precipitation maps from the region of the Netherlands and binary images of cloud coverage of France. The experimental results show that in terms of prediction performance, the proposed model is comparable to other examined models while only using a quarter of the trainable parameters.} } ``` ### Other relevant publications Your project might also benefit from exploring our other relevant publications: [Weather Elements Nowcasting](https://sites.google.com/view/siamak-mehrkanoon/projects/weather-elements-nowcasting). Please consider citing them if you find them useful. 
================================================ FILE: calc_metrics_test_set.py ================================================ import argparse from argparse import Namespace import json import csv import numpy as np import os from pathlib import Path from pprint import pprint import torch from tqdm import tqdm import lightning.pytorch as pl import matplotlib.pyplot as plt from metric.precipitation_metrics import PrecipitationMetrics from models.unet_precip_regression_lightning import PersistenceModel from root import ROOT_DIR from utils import dataset_precip, model_classes def parse_args(): parser = argparse.ArgumentParser(description="Calculate metrics for precipitation models") # Parse folder arguments as Path objects so they can be combined with the "/" operator in main() parser.add_argument("--model-folder", type=Path, default=None, help="Path to model folder (default: ROOT_DIR/checkpoints/comparison)") parser.add_argument("--threshold", type=float, default=0.5, help="Precipitation threshold") parser.add_argument("--normalized", action="store_false", dest="denormalize", default=True, help="Calculate metrics for normalized data") parser.add_argument("--load-metrics", action="store_true", help="Load existing metrics instead of running experiments") parser.add_argument("--file-type", type=str, default="txt", choices=["json", "txt", "csv"], help="Output file format (json, txt, or csv)") parser.add_argument("--save-dir", type=Path, default=None, help="Path to save the metrics") parser.add_argument("--plot", action="store_true", help="Plot the metrics") return parser.parse_args() def convert_tensors_to_python(obj): """Convert PyTorch tensors to Python native types recursively.""" if isinstance(obj, torch.Tensor): # Convert tensor to float/int, handling nan values value = obj.item() if obj.numel() == 1 else obj.tolist() return float('nan') if isinstance(value, float) and np.isnan(value) else value elif isinstance(obj, dict): return {key: convert_tensors_to_python(value) for key, value in obj.items()} elif isinstance(obj, list): return [convert_tensors_to_python(value) for value in obj] 
return obj def save_metrics(results, file_path, file_type="json"): """ Save metrics to a file in the specified format. Args: results (dict): Dictionary containing model metrics file_path (Path): Path to save the file file_type (str): File format - "json", "txt", or "csv" """ with open(file_path, "w") as f: if file_type == "csv": writer = csv.writer(f) # Write header row if results: first_model = next(iter(results.values())) writer.writerow(["Model"] + list(first_model.keys())) # Write data rows for model_name, metrics in results.items(): writer.writerow([model_name] + list(metrics.values())) else: # For json and txt formats, use json.dump json.dump(results, f, indent=4) def run_experiments(model_folder, data_file, threshold=0.5, denormalize=True): """ Run test experiments for all models in the model folder. Args: model_folder (Path): Path to the model folder data_file (Path): Path to the data file threshold (float): Threshold for the precipitation Returns: dict: Dictionary containing model metrics """ results = {} # Load the dataset dataset = dataset_precip.precipitation_maps_oversampled_h5( in_file=data_file, num_input_images=12, num_output_images=6, train=False ) # Create dataloader test_dl = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False, pin_memory=True, persistent_workers=True) # Create trainer with device specification trainer = pl.Trainer(logger=False, enable_checkpointing=False) # Find all models in the model folder models = ["PersistenceModel"] + [f for f in os.listdir(model_folder) if f.endswith(".ckpt")] if not models: raise ValueError("No checkpoint files found in the model folder.") for model_file in tqdm(models, desc="Models", leave=True): model, model_name = model_classes.get_model_class(model_file) print(f"Loading model: {model_name}") if model_file == "PersistenceModel": loaded_model = model(Namespace()) else: loaded_model = model.load_from_checkpoint(os.path.join(model_folder, model_file)) # Override the test 
metrics, required for testing with normalized data loaded_model.test_metrics = PrecipitationMetrics(threshold=threshold, denormalize=denormalize) results[model_name] = trainer.test(model=loaded_model, dataloaders=[test_dl])[0] return results def plot_metrics(results, save_path=""): """ Plot all metrics from the results dictionary. Args: results (dict): Dictionary containing metrics for different models save_path (str): Path to save the plots """ model_names = list(results.keys()) metrics = list(results[model_names[0]].keys()) for metric_name in metrics: # Get values, replacing NaN with None to skip them in the plot values = [] valid_names = [] for model_name in model_names: val = results[model_name][metric_name] # Skip NaN values if not (isinstance(val, float) and np.isnan(val)): values.append(val) valid_names.append(model_name) if len(valid_names) == 0: continue plt.figure(figsize=(10, 6)) plt.bar(valid_names, values) plt.xticks(rotation=45) plt.xlabel("Models") plt.ylabel(f"{metric_name.upper()}") plt.title(f"Comparison of different models - {metric_name.upper()}") plt.tight_layout() if save_path: plt.savefig(f"{save_path}/{metric_name}.png") plt.show() def load_metrics_from_file(file_path, file_type): """ Load metrics from a file in the specified format. 
Args: file_path (Path): Path to the metrics file file_type (str): File format - "json", "txt", or "csv" Returns: dict: Dictionary containing model metrics """ results = {} with open(file_path) as f: if file_type == "json" or file_type == "txt": # Both json and txt files are saved in JSON format in this application results = json.load(f) elif file_type == "csv": reader = csv.reader(f) headers = next(reader) # Get header row for row in reader: if len(row) > 0: model_name = row[0] # Create dictionary of metrics for this model model_metrics = {} for i in range(1, len(headers)): if i < len(row): # Try to convert to float if possible try: value = float(row[i]) except (ValueError, TypeError): value = row[i] model_metrics[headers[i]] = value results[model_name] = model_metrics return results def main(): # Parse command line arguments args = parse_args() # Variables from command line arguments load_metrics = args.load_metrics threshold = args.threshold file_type = args.file_type denormalize = args.denormalize # Define paths model_folder = ROOT_DIR / "checkpoints" / "comparison" if args.model_folder is None else args.model_folder data_file = ROOT_DIR / "data" / "precipitation" / f"train_test_2016-2019_input-length_12_img-ahead_6_rain-threshold_{int(threshold*100)}.h5" save_dir = model_folder / f"calc_metrics_test_set_results_{'denorm' if denormalize else 'norm'}" if args.save_dir is None else args.save_dir os.makedirs(save_dir, exist_ok=True) file_name = f"model_metrics_{threshold}_{'denorm' if denormalize else 'norm'}.{file_type}" file_path = save_dir / file_name # Load metrics if available if load_metrics: print(f"Loading metrics from {file_path}") results = load_metrics_from_file(file_path, file_type) else: # Run experiments print(f"Running experiments with {threshold} threshold and {'denormalized' if denormalize else 'normalized'} data") results = run_experiments(model_folder, data_file, threshold=threshold, denormalize=denormalize) # Convert tensors to Python native 
types results = convert_tensors_to_python(results) save_metrics(results, file_path, file_type) # Display results print(f"Metrics saved to {file_path}") pprint(results) # Plot metrics if requested if args.plot: plots_dir = save_dir / "plots" os.makedirs(plots_dir, exist_ok=True) plot_metrics(results, plots_dir) print(f"Plots saved to {plots_dir}") if __name__ == "__main__": main() ================================================ FILE: create_datasets.py ================================================ import h5py import numpy as np from tqdm import tqdm from root import ROOT_DIR def create_dataset(input_length: int, image_ahead: int, rain_amount_thresh: float): """Create a dataset that has target images containing at least `rain_amount_thresh` (percent) of rain.""" precipitation_folder = ROOT_DIR / "data" / "precipitation" with h5py.File( precipitation_folder / "RAD_NL25_RAC_5min_train_test_2016-2019.h5", "r", ) as orig_f: train_images = orig_f["train"]["images"] train_timestamps = orig_f["train"]["timestamps"] test_images = orig_f["test"]["images"] test_timestamps = orig_f["test"]["timestamps"] print("Train shape", train_images.shape) print("Test shape", test_images.shape) imgSize = train_images.shape[1] num_pixels = imgSize * imgSize filename = ( precipitation_folder / f"train_test_2016-2019_input-length_{input_length}_img-" f"ahead_{image_ahead}_rain-threshold_{int(rain_amount_thresh * 100)}.h5" ) with h5py.File(filename, "w") as f: train_set = f.create_group("train") test_set = f.create_group("test") train_image_dataset = train_set.create_dataset( "images", shape=(1, input_length + image_ahead, imgSize, imgSize), maxshape=(None, input_length + image_ahead, imgSize, imgSize), dtype="float32", compression="gzip", compression_opts=9, ) train_timestamp_dataset = train_set.create_dataset( "timestamps", shape=(1, input_length + image_ahead, 1), maxshape=(None, input_length + image_ahead, 1), dtype=h5py.special_dtype(vlen=str), compression="gzip", compression_opts=9, ) 
test_image_dataset = test_set.create_dataset( "images", shape=(1, input_length + image_ahead, imgSize, imgSize), maxshape=(None, input_length + image_ahead, imgSize, imgSize), dtype="float32", compression="gzip", compression_opts=9, ) test_timestamp_dataset = test_set.create_dataset( "timestamps", shape=(1, input_length + image_ahead, 1), maxshape=(None, input_length + image_ahead, 1), dtype=h5py.special_dtype(vlen=str), compression="gzip", compression_opts=9, ) origin = [[train_images, train_timestamps], [test_images, test_timestamps]] datasets = [ [train_image_dataset, train_timestamp_dataset], [test_image_dataset, test_timestamp_dataset], ] for origin_id, (images, timestamps) in enumerate(origin): image_dataset, timestamp_dataset = datasets[origin_id] # Pre-calculate all valid indices that meet the rain threshold sequence_length = input_length + image_ahead valid_indices = [] # Use vectorized operations to find valid indices for i in tqdm(range(sequence_length, len(images)), desc="Finding valid indices"): rain_pixels = np.sum(images[i] > 0) if rain_pixels >= num_pixels * rain_amount_thresh: valid_indices.append(i) # Pre-allocate the final dataset size total_sequences = len(valid_indices) image_dataset.resize(total_sequences, axis=0) timestamp_dataset.resize(total_sequences, axis=0) # Batch process the sequences for idx, i in enumerate(tqdm(valid_indices, desc="Creating sequences")): imgs = images[i - sequence_length : i] timestamps_img = timestamps[i - sequence_length : i] image_dataset[idx] = imgs timestamp_dataset[idx] = timestamps_img if __name__ == "__main__": print("Creating dataset with at least 20% of rain pixel in target image") create_dataset(input_length=12, image_ahead=6, rain_amount_thresh=0.2) print("Creating dataset with at least 50% of rain pixel in target image") create_dataset(input_length=12, image_ahead=6, rain_amount_thresh=0.5) ================================================ FILE: metric/__init__.py 
================================================ from .confusionmatrix import ConfusionMatrix from .iou import IoU from .metric import Metric __all__ = ["ConfusionMatrix", "IoU", "Metric"] ================================================ FILE: metric/confusionmatrix.py ================================================ import numpy as np import torch from metric import metric class ConfusionMatrix(metric.Metric): """Constructs a confusion matrix for multi-class classification problems. Does not support multi-label, multi-class problems. Keyword arguments: - num_classes (int): number of classes in the classification problem. - normalized (boolean, optional): Determines whether the confusion matrix is normalized. Default: False. Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py """ def __init__(self, num_classes, normalized=False): super().__init__() self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32) self.normalized = normalized self.num_classes = num_classes self.reset() def reset(self): self.conf.fill(0) def add(self, predicted, target): """Computes the confusion matrix The shape of the confusion matrix is K x K, where K is the number of classes. Keyword arguments: - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of predicted scores obtained from the model for N examples and K classes, or an N-tensor/array of integer values between 0 and K-1. - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of ground-truth classes for N examples and K classes, or an N-tensor/array of integer values between 0 and K-1. 
""" # If target and/or predicted are tensors, convert them to numpy arrays if torch.is_tensor(predicted): predicted = predicted.cpu().numpy() if torch.is_tensor(target): target = target.cpu().numpy() assert predicted.shape[0] == target.shape[0], "number of targets and predicted outputs do not match" if np.ndim(predicted) != 1: assert ( predicted.shape[1] == self.num_classes ), "number of predictions does not match size of confusion matrix" predicted = np.argmax(predicted, 1) else: assert (predicted.max() < self.num_classes) and ( predicted.min() >= 0 ), "predicted values are not between 0 and k-1" if np.ndim(target) != 1: assert target.shape[1] == self.num_classes, "Onehot target does not match size of confusion matrix" assert (target >= 0).all() and (target <= 1).all(), "in one-hot encoding, target values should be 0 or 1" assert (target.sum(1) == 1).all(), "multi-label setting is not supported" target = np.argmax(target, 1) else: assert (target.max() < self.num_classes) and (target.min() >= 0), "target values are not between 0 and k-1" # hack for bincounting 2 arrays together x = predicted + self.num_classes * target bincount_2d = np.bincount(x.astype(np.int32), minlength=self.num_classes**2) assert bincount_2d.size == self.num_classes**2 conf = bincount_2d.reshape((self.num_classes, self.num_classes)) self.conf += conf def value(self): """ Returns: Confustion matrix of K rows and K columns, where rows corresponds to ground-truth targets and columns corresponds to predicted targets. 
""" if self.normalized: conf = self.conf.astype(np.float32) return conf / conf.sum(1).clip(min=1e-12)[:, None] else: return self.conf ================================================ FILE: metric/iou.py ================================================ import numpy as np from metric import metric from metric.confusionmatrix import ConfusionMatrix # Taken from https://github.com/davidtvs/PyTorch-ENet/tree/master/metric class IoU(metric.Metric): """Computes the intersection over union (IoU) per class and corresponding mean (mIoU). Intersection over union (IoU) is a common evaluation metric for semantic segmentation. The predictions are first accumulated in a confusion matrix and the IoU is computed from it as follows: IoU = true_positive / (true_positive + false_positive + false_negative). Keyword arguments: - num_classes (int): number of classes in the classification problem - normalized (boolean, optional): Determines whether or not the confusion matrix is normalized or not. Default: False. - ignore_index (int or iterable, optional): Index of the classes to ignore when computing the IoU. Can be an int, or any iterable of ints. """ def __init__(self, num_classes, normalized=False, ignore_index=None): super().__init__() self.conf_metric = ConfusionMatrix(num_classes, normalized) if ignore_index is None: self.ignore_index = None elif isinstance(ignore_index, int): self.ignore_index = (ignore_index,) else: try: self.ignore_index = tuple(ignore_index) except TypeError as err: raise ValueError("'ignore_index' must be an int or iterable") from err def reset(self): self.conf_metric.reset() def add(self, predicted, target): """Adds the predicted and target pair to the IoU metric. Keyword arguments: - predicted (Tensor): Can be a (N, K, H, W) tensor of predicted scores obtained from the model for N examples and K classes, or (N, H, W) tensor of integer values between 0 and K-1. 
- target (Tensor): Can be a (N, K, H, W) tensor of target scores for N examples and K classes, or (N, H, W) tensor of integer values between 0 and K-1. """ # Dimensions check assert predicted.size(0) == target.size(0), "number of targets and predicted outputs do not match" assert ( predicted.dim() == 3 or predicted.dim() == 4 ), "predictions must be of dimension (N, H, W) or (N, K, H, W)" assert target.dim() == 3 or target.dim() == 4, "targets must be of dimension (N, H, W) or (N, K, H, W)" # If the tensor is in categorical format convert it to integer format if predicted.dim() == 4: _, predicted = predicted.max(1) if target.dim() == 4: _, target = target.max(1) self.conf_metric.add(predicted.view(-1), target.view(-1)) def value(self): """Computes the IoU and mean IoU. The mean computation ignores NaN elements of the IoU array. Returns: Tuple: (IoU, mIoU). The first output is the per class IoU, for K classes it's numpy.ndarray with K elements. The second output, is the mean IoU. """ conf_matrix = self.conf_metric.value() if self.ignore_index is not None: conf_matrix[:, self.ignore_index] = 0 conf_matrix[self.ignore_index, :] = 0 true_positive = np.diag(conf_matrix) false_positive = np.sum(conf_matrix, 0) - true_positive false_negative = np.sum(conf_matrix, 1) - true_positive # Just in case we get a division by 0, ignore/hide the error with np.errstate(divide="ignore", invalid="ignore"): iou = true_positive / (true_positive + false_positive + false_negative) return iou, np.nanmean(iou) ================================================ FILE: metric/metric.py ================================================ from typing import Protocol class Metric(Protocol): """Base class for all metrics.""" def reset(self): pass def add(self, predicted, target): pass def value(self): pass ================================================ FILE: metric/precipitation_metrics.py ================================================ import torch from torchmetrics import Metric import numpy as np 
class PrecipitationMetrics(Metric): """ A custom metric for precipitation forecasting that computes both regression metrics (MSE) and classification metrics (precision, recall, F1, etc.) based on a threshold. """ def __init__(self, threshold=0.5, denormalize=True, dist_sync_on_step=False): """ Args: threshold: Threshold to convert continuous predictions into binary masks (in mm/h) denormalize: If True, applies the factor to undo the Min-Max normalization dist_sync_on_step: Synchronize metric state across processes at each step """ super().__init__(dist_sync_on_step=dist_sync_on_step) self.threshold = threshold self.denormalize = denormalize self.factor = 47.83 # Factor to denormalize predictions (result is in mm/5min; update() converts to mm/h) # Add states for regression metrics self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total_loss_denorm", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total_samples", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total_pixels", default=torch.tensor(0), dist_reduce_fx="sum") # Add states for classification metrics self.add_state("total_tp", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total_fp", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total_tn", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total_fn", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds, target): """ Update the metric states with new predictions and targets. 
Args: preds: Model predictions (normalized) target: Ground truth (normalized) """ # Check for NaN values if torch.isnan(preds).any() or torch.isnan(target).any(): print("Warning: NaN values detected in predictions or targets") return # Make sure preds and target have the same shape if preds.shape != target.shape: if len(preds.shape) < len(target.shape): preds = preds.unsqueeze(0) elif len(preds.shape) > len(target.shape): preds = preds.squeeze() # If squeezing made it too small, add dimension back if len(preds.shape) < len(target.shape): preds = preds.unsqueeze(0) # Calculate MSE loss (normalized) batch_size = target.size(0) loss = torch.nn.functional.mse_loss(preds, target, reduction="sum") / batch_size self.total_loss += loss self.total_samples += batch_size # Use batch_size instead of 1 self.total_pixels += target.numel() # Denormalize if needed if self.denormalize: preds_updated = preds * self.factor target_updated = target * self.factor # Calculate denormalized MSE loss loss_denorm = torch.nn.functional.mse_loss(preds_updated, target_updated, reduction="sum") / batch_size self.total_loss_denorm += loss_denorm else: preds_updated = preds target_updated = target # Convert to mm/h for classification metrics (multiply by 12 for 5min to hourly rate) preds_hourly = preds_updated * 12 target_hourly = target_updated * 12 # Apply threshold to get binary masks preds_mask = preds_hourly > self.threshold target_mask = target_hourly > self.threshold # Compute confusion matrix confusion = target_mask.view(-1) * 2 + preds_mask.view(-1) bincount = torch.bincount(confusion, minlength=4) # Update confusion matrix states self.total_tn += bincount[0] self.total_fp += bincount[1] self.total_fn += bincount[2] self.total_tp += bincount[3] def compute(self): """ Compute the final metrics from the accumulated states. Returns: A dictionary containing all metrics. 
""" # Compute regression metrics mse = self.total_loss / self.total_samples mse_denorm = self.total_loss_denorm / self.total_samples if self.denormalize else torch.tensor(float('nan')) mse_pixel = self.total_loss_denorm / self.total_pixels if self.denormalize else torch.tensor(float('nan')) # Compute classification metrics precision = self.total_tp / (self.total_tp + self.total_fp) if (self.total_tp + self.total_fp) > 0 else torch.tensor(float('nan')) recall = self.total_tp / (self.total_tp + self.total_fn) if (self.total_tp + self.total_fn) > 0 else torch.tensor(float('nan')) accuracy = (self.total_tp + self.total_tn) / (self.total_tp + self.total_tn + self.total_fp + self.total_fn) # F1 score f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(float('nan')) # Critical Success Index (CSI) or Threat Score csi = self.total_tp / (self.total_tp + self.total_fn + self.total_fp) if (self.total_tp + self.total_fn + self.total_fp) > 0 else torch.tensor(float('nan')) # False Alarm Ratio (FAR) far = self.total_fp / (self.total_tp + self.total_fp) if (self.total_tp + self.total_fp) > 0 else torch.tensor(float('nan')) # Heidke Skill Score (HSS) denom = ((self.total_tp + self.total_fn) * (self.total_fn + self.total_tn) + (self.total_tp + self.total_fp) * (self.total_fp + self.total_tn)) hss = ((self.total_tp * self.total_tn) - (self.total_fn * self.total_fp)) / denom if denom > 0 else torch.tensor(float('nan')) return { "mse": mse, "mse_denorm": mse_denorm, "mse_pixel": mse_pixel, "precision": precision, "recall": recall, "accuracy": accuracy, "f1": f1, "csi": csi, "far": far, "hss": hss } ================================================ FILE: models/SmaAt_UNet.py ================================================ from torch import nn from models.unet_parts import OutConv from models.unet_parts_depthwise_separable import DoubleConvDS, UpDS, DownDS from models.layers import CBAM class SmaAt_UNet(nn.Module): def __init__( self, 
        n_channels,
        n_classes,
        kernels_per_layer=2,
        bilinear=True,
        reduction_ratio=16,
    ):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)
        self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
        self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)
        self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
        self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)
        self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
        self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)
        self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
        factor = 2 if self.bilinear else 1
        self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)
        self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)
        self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)

        self.outc = OutConv(64, self.n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x1Att = self.cbam1(x1)
        x2 = self.down1(x1)
        x2Att = self.cbam2(x2)
        x3 = self.down2(x2)
        x3Att = self.cbam3(x3)
        x4 = self.down3(x3)
        x4Att = self.cbam4(x4)
        x5 = self.down4(x4)
        x5Att = self.cbam5(x5)
        x = self.up1(x5Att, x4Att)
        x = self.up2(x, x3Att)
        x = self.up3(x, x2Att)
        x = self.up4(x, x1Att)
        logits = self.outc(x)
        return logits


================================================
FILE: models/__init__.py
================================================


================================================
FILE: models/layers.py
================================================
import torch
from torch import nn
import torch.nn.functional as F


# Taken from https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14
class DepthToSpace(nn.Module):
    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, self.bs, self.bs, C // (self.bs**2), H, W)  # (N, bs, bs, C//bs^2, H, W)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)
        x = x.view(N, C // (self.bs**2), H * self.bs, W * self.bs)  # (N, C//bs^2, H * bs, W * bs)
        return x


class SpaceToDepth(nn.Module):
    # Expects the following shape: Batch, Channel, Height, Width
    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)  # (N, C, H//bs, bs, W//bs, bs)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        x = x.view(N, C * (self.bs**2), H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
        return x


class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, output_channels, kernel_size, padding=0, kernels_per_layer=1):
        super().__init__()
        # In Tensorflow DepthwiseConv2D has depth_multiplier instead of kernels_per_layer
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels * kernels_per_layer,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_channels,
        )
        self.pointwise = nn.Conv2d(in_channels * kernels_per_layer, output_channels, kernel_size=1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x


class DoubleDense(nn.Module):
    def __init__(self, in_channels, hidden_neurons, output_channels):
        super().__init__()
        self.dense1 = nn.Linear(in_channels, out_features=hidden_neurons)
        self.dense2 = nn.Linear(in_features=hidden_neurons, out_features=hidden_neurons // 2)
        self.dense3 = nn.Linear(in_features=hidden_neurons // 2, out_features=output_channels)

    def forward(self, x):
        out = F.relu(self.dense1(x.view(x.size(0), -1)))
        out = F.relu(self.dense2(out))
        out = self.dense3(out)
        return out


class DoubleDSConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_ds_conv = nn.Sequential(
            DepthwiseSeparableConv(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            DepthwiseSeparableConv(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_ds_conv(x)


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelAttention(nn.Module):
    def __init__(self, input_channels, reduction_ratio=16):
        super().__init__()
        self.input_channels = input_channels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # https://github.com/luuuyi/CBAM.PyTorch/blob/master/model/resnet_cbam.py
        # uses Convolutions instead of Linear
        self.MLP = nn.Sequential(
            Flatten(),
            nn.Linear(input_channels, input_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(input_channels // reduction_ratio, input_channels),
        )

    def forward(self, x):
        # Take the input and apply average and max pooling
        avg_values = self.avg_pool(x)
        max_values = self.max_pool(x)
        out = self.MLP(avg_values) + self.MLP(max_values)
        scale = x * torch.sigmoid(out).unsqueeze(2).unsqueeze(3).expand_as(x)
        return scale


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(1)

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        out = self.bn(out)
        scale = x * torch.sigmoid(out)
        return scale


class CBAM(nn.Module):
    def __init__(self, input_channels, reduction_ratio=16,
                 kernel_size=7):
        super().__init__()
        self.channel_att = ChannelAttention(input_channels, reduction_ratio=reduction_ratio)
        self.spatial_att = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        out = self.channel_att(x)
        out = self.spatial_att(out)
        return out


================================================
FILE: models/regression_lightning.py
================================================
import lightning.pytorch as pl
import torch
from torch import nn, optim, Tensor
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from utils import dataset_precip
import argparse
import numpy as np
from metric.precipitation_metrics import PrecipitationMetrics
from utils.formatting import make_metrics_str


class UNetBase(pl.LightningModule):
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--model",
            type=str,
            default="UNet",
            choices=["UNet", "UNetDS", "UNetAttention", "UNetDSAttention", "PersistenceModel"],
        )
        parser.add_argument("--n_channels", type=int, default=12)
        parser.add_argument("--n_classes", type=int, default=1)
        parser.add_argument("--kernels_per_layer", type=int, default=1)
        # NOTE: argparse passes the raw string to bool(), so any non-empty value (even "False") parses as True
        parser.add_argument("--bilinear", type=bool, default=True)
        parser.add_argument("--reduction_ratio", type=int, default=16)
        parser.add_argument("--lr_patience", type=int, default=5)
        parser.add_argument("--threshold", type=float, default=0.5)
        return parser

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.train_metrics = PrecipitationMetrics(
            threshold=self.hparams.threshold if hasattr(self.hparams, 'threshold') else 0.5
        )
        self.val_metrics = PrecipitationMetrics(
            threshold=self.hparams.threshold if hasattr(self.hparams, 'threshold') else 0.5
        )
        self.test_metrics = PrecipitationMetrics(
            threshold=self.hparams.threshold if hasattr(self.hparams, 'threshold') else 0.5
        )

    def forward(self, x):
        pass

    def configure_optimizers(self):
        opt = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = {
            "scheduler": optim.lr_scheduler.ReduceLROnPlateau(
                opt, mode="min", factor=0.1, patience=self.hparams.lr_patience
            ),
            "monitor": "val_loss",  # Default: val_loss
        }
        return [opt], [scheduler]

    def loss_func(self, y_pred, y_true):
        # Ensure consistent shapes before computing loss
        if y_pred.dim() > y_true.dim():
            y_pred = y_pred.squeeze(1)
        elif y_true.dim() > y_pred.dim():
            y_pred = y_pred.unsqueeze(1)
        # reduction="mean" is average of every pixel, but I want average of image
        return nn.functional.mse_loss(y_pred, y_true, reduction="sum") / y_true.size(0)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.loss_func(y_pred, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # Update training metrics with detached tensors
        self.train_metrics.update(y_pred.detach(), y.detach())
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.loss_func(y_pred, y)
        self.log("val_loss", loss, prog_bar=True)
        # Update validation metrics with detached tensors
        self.val_metrics.update(y_pred.detach(), y.detach())

    def test_step(self, batch, batch_idx):
        """Calculate the loss (MSE per default) and other metrics on the test set normalized and denormalized."""
        x, y = batch
        y_pred = self(x)
        # Update the test metrics
        self.test_metrics.update(y_pred.detach(), y.detach())

    def on_test_epoch_end(self):
        """Compute and log all metrics at the end of the test epoch."""
        test_metrics_dict = self.test_metrics.compute()
        # Print all metrics in one line
        print(f"\n\nEpoch {self.current_epoch} - Test Metrics: {make_metrics_str(test_metrics_dict)}")
        # Reset the metrics for the next test epoch
        self.test_metrics.reset()

    def on_train_epoch_end(self):
        """Compute and log all metrics at the end of the training epoch."""
        train_metrics_dict = self.train_metrics.compute()
        print(f"\n\nEpoch {self.current_epoch} - Train Metrics: {make_metrics_str(train_metrics_dict)}")
        self.train_metrics.reset()

    def on_validation_epoch_end(self):
        """Compute and log all metrics at the end of the validation epoch."""
        val_metrics_dict = self.val_metrics.compute()
        print(f"\n\nEpoch {self.current_epoch} - Validation Metrics: {make_metrics_str(val_metrics_dict)}")
        self.val_metrics.reset()


class PrecipRegressionBase(UNetBase):
    @staticmethod
    def add_model_specific_args(parent_parser):
        parent_parser = UNetBase.add_model_specific_args(parent_parser)
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--num_input_images", type=int, default=12)
        parser.add_argument("--num_output_images", type=int, default=6)
        parser.add_argument("--valid_size", type=float, default=0.1)
        # NOTE: type=bool has the same caveat as --bilinear: any non-empty string parses as True
        parser.add_argument("--use_oversampled_dataset", type=bool, default=True)
        parser.n_channels = parser.parse_args().num_input_images
        parser.n_classes = 1
        return parser

    def __init__(self, hparams):
        super().__init__(hparams=hparams)
        self.train_dataset = None
        self.valid_dataset = None
        self.train_sampler = None
        self.valid_sampler = None

    def prepare_data(self):
        # train_transform = transforms.Compose([
        #     transforms.RandomHorizontalFlip()]
        # )
        train_transform = None
        valid_transform = None
        precip_dataset = (
            dataset_precip.precipitation_maps_oversampled_h5
            if self.hparams.use_oversampled_dataset
            else dataset_precip.precipitation_maps_h5
        )
        self.train_dataset = precip_dataset(
            in_file=self.hparams.dataset_folder,
            num_input_images=self.hparams.num_input_images,
            num_output_images=self.hparams.num_output_images,
            train=True,
            transform=train_transform,
        )
        # The validation set is drawn from the training file; the samplers below keep the two index sets disjoint
        self.valid_dataset = precip_dataset(
            in_file=self.hparams.dataset_folder,
            num_input_images=self.hparams.num_input_images,
            num_output_images=self.hparams.num_output_images,
            train=True,
            transform=valid_transform,
        )

        num_train = len(self.train_dataset)
        indices = list(range(num_train))
        split = int(np.floor(self.hparams.valid_size * num_train))
        np.random.shuffle(indices)

        train_idx, valid_idx = indices[split:], indices[:split]
        self.train_sampler = SubsetRandomSampler(train_idx)
        self.valid_sampler = SubsetRandomSampler(valid_idx)

    def train_dataloader(self):
        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            sampler=self.train_sampler,
            pin_memory=True,
            # The following can/should be tweaked depending on the number of CPU cores
            num_workers=1,
            persistent_workers=True,
        )
        return train_loader

    def val_dataloader(self):
        valid_loader = DataLoader(
            self.valid_dataset,
            batch_size=self.hparams.batch_size,
            sampler=self.valid_sampler,
            pin_memory=True,
            # The following can/should be tweaked depending on the number of CPU cores
            num_workers=1,
            persistent_workers=True,
        )
        return valid_loader


class PersistenceModel(UNetBase):
    def forward(self, x):
        return x[:, -1:, :, :]


================================================
FILE: models/unet_parts.py
================================================
""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


================================================
FILE: models/unet_parts_depthwise_separable.py
================================================
""" Parts of the U-Net model """

# Base model taken from: https://github.com/milesial/Pytorch-UNet
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.layers import DepthwiseSeparableConv


class DoubleConvDS(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None, kernels_per_layer=1):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            DepthwiseSeparableConv(
                in_channels,
                mid_channels,
                kernel_size=3,
                kernels_per_layer=kernels_per_layer,
                padding=1,
            ),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            DepthwiseSeparableConv(
                mid_channels,
                out_channels,
                kernel_size=3,
                kernels_per_layer=kernels_per_layer,
                padding=1,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class DownDS(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels, kernels_per_layer=1):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConvDS(in_channels, out_channels, kernels_per_layer=kernels_per_layer),
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class UpDS(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True, kernels_per_layer=1):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConvDS(
                in_channels,
                out_channels,
                in_channels // 2,
                kernels_per_layer=kernels_per_layer,
            )
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConvDS(in_channels, out_channels, kernels_per_layer=kernels_per_layer)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


================================================
FILE: models/unet_precip_regression_lightning.py
================================================
from models.unet_parts import Down, DoubleConv, Up, OutConv
from models.unet_parts_depthwise_separable import DoubleConvDS, UpDS, DownDS
from models.layers import CBAM
from models.regression_lightning import PrecipRegressionBase, PersistenceModel


class UNet(PrecipRegressionBase):
    def __init__(self, hparams):
        super().__init__(hparams=hparams)
        self.n_channels = self.hparams.n_channels
        self.n_classes = self.hparams.n_classes
        self.bilinear = self.hparams.bilinear

        self.inc = DoubleConv(self.n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if self.bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, self.bilinear)
        self.up2 = Up(512, 256 // factor, self.bilinear)
        self.up3 = Up(256, 128 // factor, self.bilinear)
        self.up4 = Up(128, 64, self.bilinear)

        self.outc = OutConv(64, self.n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


class UNetAttention(PrecipRegressionBase):
    def __init__(self, hparams):
        super().__init__(hparams=hparams)
        self.n_channels = self.hparams.n_channels
        self.n_classes = self.hparams.n_classes
        self.bilinear = self.hparams.bilinear
        reduction_ratio = self.hparams.reduction_ratio

        self.inc = DoubleConv(self.n_channels, 64)
        self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
        self.down1 = Down(64, 128)
        self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
        self.down2 = Down(128, 256)
        self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
        self.down3 = Down(256, 512)
        self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
        factor = 2 if self.bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)
        self.up1 = Up(1024, 512 // factor, self.bilinear)
        self.up2 = Up(512, 256 // factor, self.bilinear)
        self.up3 = Up(256, 128 // factor, self.bilinear)
        self.up4 = Up(128, 64,
                      self.bilinear)

        self.outc = OutConv(64, self.n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x1Att = self.cbam1(x1)
        x2 = self.down1(x1)
        x2Att = self.cbam2(x2)
        x3 = self.down2(x2)
        x3Att = self.cbam3(x3)
        x4 = self.down3(x3)
        x4Att = self.cbam4(x4)
        x5 = self.down4(x4)
        x5Att = self.cbam5(x5)
        x = self.up1(x5Att, x4Att)
        x = self.up2(x, x3Att)
        x = self.up3(x, x2Att)
        x = self.up4(x, x1Att)
        logits = self.outc(x)
        return logits


class UNetDS(PrecipRegressionBase):
    def __init__(self, hparams):
        super().__init__(hparams=hparams)
        self.n_channels = self.hparams.n_channels
        self.n_classes = self.hparams.n_classes
        self.bilinear = self.hparams.bilinear
        kernels_per_layer = self.hparams.kernels_per_layer

        self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)
        self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)
        self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)
        self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)
        factor = 2 if self.bilinear else 1
        self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)
        self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)

        self.outc = OutConv(64, self.n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


class UNetDSAttention(PrecipRegressionBase):
    def __init__(self, hparams):
        super().__init__(hparams=hparams)
        self.n_channels = self.hparams.n_channels
        self.n_classes = self.hparams.n_classes
        self.bilinear = self.hparams.bilinear
        reduction_ratio = self.hparams.reduction_ratio
        kernels_per_layer = self.hparams.kernels_per_layer

        self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)
        self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
        self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)
        self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
        self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)
        self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
        self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)
        self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
        factor = 2 if self.bilinear else 1
        self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)
        self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)
        self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)

        self.outc = OutConv(64, self.n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x1Att = self.cbam1(x1)
        x2 = self.down1(x1)
        x2Att = self.cbam2(x2)
        x3 = self.down2(x2)
        x3Att = self.cbam3(x3)
        x4 = self.down3(x3)
        x4Att = self.cbam4(x4)
        x5 = self.down4(x4)
        x5Att = self.cbam5(x5)
        x = self.up1(x5Att, x4Att)
        x = self.up2(x, x3Att)
        x = self.up3(x, x2Att)
        x = self.up4(x, x1Att)
        logits = self.outc(x)
        return logits


class UNetDSAttention4CBAMs(PrecipRegressionBase):
    def __init__(self, hparams):
        super().__init__(hparams=hparams)
        self.n_channels = self.hparams.n_channels
        self.n_classes = self.hparams.n_classes
        self.bilinear = self.hparams.bilinear
        reduction_ratio = self.hparams.reduction_ratio
        kernels_per_layer = self.hparams.kernels_per_layer

        self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)
        self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
        self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)
        self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
        self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)
        self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
        self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)
        self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
        factor = 2 if self.bilinear else 1
        self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)
        self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
        self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)

        self.outc = OutConv(64, self.n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x1Att = self.cbam1(x1)
        x2 = self.down1(x1)
        x2Att = self.cbam2(x2)
        x3 = self.down2(x2)
        x3Att = self.cbam3(x3)
        x4 = self.down3(x3)
        x4Att = self.cbam4(x4)
        x5 = self.down4(x4)
        x = self.up1(x5, x4Att)
        x = self.up2(x, x3Att)
        x = self.up3(x, x2Att)
        x = self.up4(x, x1Att)
        logits = self.outc(x)
        return logits


================================================
FILE: plot_examples.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import models.unet_precip_regression_lightning as unet_regr\n",
    "from utils import dataset_precip\n",
    "\n",
    "dataset = dataset_precip.precipitation_maps_oversampled_h5(\n",
    "    in_file=\"data/precipitation/train_test_2016-2019_input-length_12_img-ahead_6_rain-threshold_50.h5\",\n",
    "    num_input_images=12,\n",
    "    num_output_images=6, train=False)\n",
    "# 222, 444, 777, 1337 is fine\n",
    "x, y_true = dataset[777]\n",
    "# add batch dimension\n",
    "x = torch.tensor(x).unsqueeze(0)\n",
    "y_true = 
torch.tensor(y_true).unsqueeze(0)\n",
    "print(x.shape, y_true.shape)\n",
    "\n",
    "models_dir = \"checkpoints/comparison\"\n",
    "precip_models = [m for m in os.listdir(models_dir) if \".ckpt\" in m]\n",
    "print(precip_models)\n",
    "\n",
    "loss_func = nn.MSELoss(reduction=\"sum\")\n",
    "print(\"Persistence\", loss_func(x.squeeze()[-1]*47.83, y_true.squeeze()*47.83).item())\n",
    "\n",
    "plot_images = dict()\n",
    "for i, model_file in enumerate(precip_models):\n",
    "    # model_name = \"\".join(model_file.split(\"_\")[0])\n",
    "    if \"UNetAttention\" in model_file:\n",
    "        model_name = \"UNet with CBAM\"\n",
    "        model = unet_regr.UNetAttention\n",
    "    elif \"UNetDSAttention4CBAMs\" in model_file:\n",
    "        model_name = \"UNetDS Attention 4CBAMs\"\n",
    "        model = unet_regr.UNetDSAttention4CBAMs\n",
    "    elif \"UNetDSAttention\" in model_file:\n",
    "        model_name = \"SmaAt-UNet\"\n",
    "        model = unet_regr.UNetDSAttention\n",
    "    elif \"UNetDS\" in model_file:\n",
    "        model_name = \"UNet with DSC\"\n",
    "        model = unet_regr.UNetDS\n",
    "    elif \"UNet\" in model_file:\n",
    "        model_name = \"UNet\"\n",
    "        model = unet_regr.UNet\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Model not found\")\n",
    "    model = model.load_from_checkpoint(f\"{models_dir}/{model_file}\")\n",
    "    # loaded = torch.load(f\"{models_dir}/{model_file}\")\n",
    "    # model = loaded[\"model\"]\n",
    "    # model.load_state_dict(loaded[\"state_dict\"])\n",
    "    # loss_func = nn.CrossEntropyLoss(weight=torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.]))\n",
    "\n",
    "    model.to(\"cpu\")\n",
    "    model.eval()\n",
    "\n",
    "    # get prediction of model\n",
    "    with torch.no_grad():\n",
    "        y_pred = model(x)\n",
    "    y_pred = y_pred.squeeze()\n",
    "    print(model_name, loss_func(y_pred*47.83, y_true.squeeze()*47.83).item())\n",
    "\n",
    "    plot_images[model_name] = y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#### Plotting ####\n",
    "plt.figure(figsize=(10, 20))\n",
    "plt.subplot(len(precip_models)+2, 1, 1)\n",
    "plt.imshow(y_true.squeeze()*47.83)\n",
    "plt.colorbar()\n",
    "plt.title(\"Ground truth\")\n",
    "plt.axis(\"off\")\n",
    "\n",
    "plt.subplot(len(precip_models)+2, 1, 2)\n",
    "plt.imshow(x.squeeze()[-1]*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\n",
    "plt.colorbar()\n",
    "plt.title(\"Persistence\")\n",
    "plt.axis(\"off\")\n",
    "\n",
    "for i, (model_name, y_pred) in enumerate(plot_images.items()):\n",
    "    print(model_name)\n",
    "    plt.subplot(len(precip_models)+2, 1, i+3)\n",
    "    plt.imshow(y_pred*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\n",
    "    plt.colorbar()\n",
    "    plt.title(model_name)\n",
    "    plt.axis(\"off\")\n",
    "# plt.colorbar()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "fig, axesgrid = plt.subplots(nrows=2, ncols=(len(precip_models)+2)//2, figsize=(20, 10))\n",
    "axes = axesgrid.flat\n",
    "im = axes[0].imshow(y_true.squeeze()*47.83)\n",
    "axes[0].set_title(\"Groundtruth\", fontsize=15)\n",
    "axes[0].set_axis_off()\n",
    "axes[1].imshow(x.squeeze()[-1]*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\n",
    "axes[1].set_title(\"Persistence\", fontsize=15)\n",
    "axes[1].set_axis_off()\n",
    "\n",
    "# for ax in axes.flat[2:]:\n",
    "for i, (model_name, y_pred) in enumerate(plot_images.items()):\n",
    "    print(model_name)\n",
    "#     plt.subplot(len(precip_models)+2, 1, i+3)\n",
    "    axes[i+2].imshow(y_pred*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\n",
    "    axes[i+2].set_title(model_name, fontsize=15)\n",
    "    axes[i+2].set_axis_off()\n",
    "# im = ax.imshow(np.random.random((16, 16)), cmap='viridis',\n",
    "# vmin=0, vmax=1)\n",
    "\n",
    "output_dir = \"plots\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "fig.subplots_adjust(wspace=0.02)\n",
    "cbar = fig.colorbar(im, ax=axesgrid.ravel().tolist(), shrink=1)\n",
    "cbar.ax.set_ylabel('mm/5min', fontsize=15)\n",
    "plt.savefig(os.path.join(output_dir, \"Precip_example.eps\"))\n",
    "plt.savefig(os.path.join(output_dir, \"Precip_example.png\"), dpi=300)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}


================================================
FILE: pyproject.toml
================================================
[project]
name = "smaat-unet"
version = "0.1.0"
description = "Code for the paper `SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture`"
authors = [{ name = "Kevin Trebing", email = "Kevin.Trebing@gmx.net" }]
readme = "README.md"
requires-python = ">=3.10.11"
dependencies = [
    "h5py>=3.13.0",
    "lightning>=2.5.0.post0",
    "matplotlib>=3.10.0",
    "numpy>=2.2.3",
    "tensorboard>=2.19.0",
    "torch>=2.6.0",
    "torchsummary>=1.5.1",
    "tqdm>=4.67.1",
]

[dependency-groups]
dev = [
    "mypy>=1.15.0",
    "pre-commit>=4.1.0",
    "ruff>=0.9.7",
]
scripts = [
    "ipykernel>=6.29.5",
]

[[tool.uv.index]]
name = "pytorch-cu126"
url = "https://download.pytorch.org/whl/cu126"
explicit = true

[tool.uv.sources]
torch = [
    { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
torchvision = [
    { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]

[tool.mypy]
python_version = "3.10"
ignore_missing_imports = true

[tool.ruff]
line-length = 120

[tool.ruff.lint]
select = [
    "E",  # pycodestyle Errors
    "W",  # pycodestyle Warnings
    "F",  # pyflakes
    "UP",  # pyupgrade
    # "D",  # pydocstyle
    "B",  # flake8-bugbear
    "C4",  # flake8-comprehensions
    "A",  # flake8-builtins
    "C90",  # mccabe complexity
    "DJ",  # flake8-django
    "PIE",  # flake8-pie
    # "SIM",  # flake8-simplify
]
ignore = [
    "B905",  # length-checking in zip() only introduced with Python 3.10 (PEP618)
    "UP007",  # Python version 3.9 does not allow writing union types as X | Y
    "UP038",  # Python version 3.9 does not allow writing union types as X | Y
    "D202",  # No blank lines allowed after function docstring
    "D100",  # Missing docstring in public module
    "D104",  # Missing docstring in public package
    "D205",  # 1 blank line required between summary line and description
    "D203",  # One blank line before class docstring
    "D212",  # Multi-line docstring summary should start at the first line
    "D213",  # Multi-line docstring summary should start at the second line
    "D400",  # First line of docstring should end with a period
    "D404",  # First word of the docstring should not be "This"
    "D415",  # First line should end with a period, question mark, or exclamation point
    "DJ001",  # Avoid using `null=True` on string-based fields
]
exclude = [
    ".git",
    ".local",
    ".cache",
    ".venv",
    "./venv",
    ".vscode",
    "__pycache__",
    "docs",
    "build",
    "dist",
    "notebooks",
    "migrations"
]

[tool.ruff.lint.mccabe]
# Flag errors (`C901`) whenever the complexity level exceeds 12.
max-complexity = 12 ================================================ FILE: requirements-dev.txt ================================================ # This file was autogenerated by uv via the following command: # uv export --format=requirements-txt --group dev --no-hashes --no-sources --output-file=requirements-dev.txt absl-py==2.1.0 aiohappyeyeballs==2.4.6 aiohttp==3.11.12 aiosignal==1.3.2 async-timeout==5.0.1 ; python_full_version < '3.11' attrs==25.1.0 cfgv==3.4.0 colorama==0.4.6 ; sys_platform == 'win32' contourpy==1.3.1 cycler==0.12.1 distlib==0.3.9 filelock==3.17.0 fonttools==4.56.0 frozenlist==1.5.0 fsspec==2025.2.0 grpcio==1.70.0 h5py==3.13.0 identify==2.6.7 idna==3.10 jinja2==3.1.5 kiwisolver==1.4.8 lightning==2.5.0.post0 lightning-utilities==0.12.0 markdown==3.7 markupsafe==3.0.2 matplotlib==3.10.0 mpmath==1.3.0 multidict==6.1.0 mypy==1.15.0 mypy-extensions==1.0.0 networkx==3.4.2 nodeenv==1.9.1 numpy==2.2.3 nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 
'linux' nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' packaging==24.2 pillow==11.1.0 platformdirs==4.3.6 pre-commit==4.1.0 propcache==0.3.0 protobuf==5.29.3 pyparsing==3.2.1 python-dateutil==2.9.0.post0 pytorch-lightning==2.5.0.post0 pyyaml==6.0.2 ruff==0.9.7 setuptools==75.8.0 six==1.17.0 sympy==1.13.1 tensorboard==2.19.0 tensorboard-data-server==0.7.2 tomli==2.2.1 ; python_full_version < '3.11' torch==2.6.0 torchmetrics==1.6.1 torchsummary==1.5.1 tqdm==4.67.1 triton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux' typing-extensions==4.12.2 virtualenv==20.29.2 werkzeug==3.1.3 yarl==1.18.3 ================================================ FILE: requirements-full.txt ================================================ # This file was autogenerated by uv via the following command: # uv export --format=requirements-txt --all-groups --no-hashes --no-sources --output-file=requirements-full.txt absl-py==2.1.0 aiohappyeyeballs==2.4.6 aiohttp==3.11.12 aiosignal==1.3.2 appnope==0.1.4 ; sys_platform == 'darwin' asttokens==3.0.0 async-timeout==5.0.1 ; python_full_version < '3.11' attrs==25.1.0 cffi==1.17.1 ; implementation_name == 'pypy' cfgv==3.4.0 colorama==0.4.6 ; sys_platform == 'win32' comm==0.2.2 contourpy==1.3.1 cycler==0.12.1 debugpy==1.8.12 decorator==5.1.1 distlib==0.3.9 exceptiongroup==1.2.2 ; python_full_version < '3.11' executing==2.2.0 filelock==3.17.0 fonttools==4.56.0 frozenlist==1.5.0 fsspec==2025.2.0 grpcio==1.70.0 h5py==3.13.0 identify==2.6.7 idna==3.10 ipykernel==6.29.5 ipython==8.32.0 jedi==0.19.2 jinja2==3.1.5 jupyter-client==8.6.3 jupyter-core==5.7.2 kiwisolver==1.4.8 lightning==2.5.0.post0 lightning-utilities==0.12.0 markdown==3.7 markupsafe==3.0.2 matplotlib==3.10.0 matplotlib-inline==0.1.7 mpmath==1.3.0 multidict==6.1.0 mypy==1.15.0 mypy-extensions==1.0.0 nest-asyncio==1.6.0 networkx==3.4.2 nodeenv==1.9.1 numpy==2.2.3 nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 
'linux' nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' packaging==24.2 parso==0.8.4 pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' pillow==11.1.0 platformdirs==4.3.6 pre-commit==4.1.0 prompt-toolkit==3.0.50 propcache==0.3.0 protobuf==5.29.3 psutil==7.0.0 ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' pure-eval==0.2.3 pycparser==2.22 ; implementation_name == 'pypy' pygments==2.19.1 pyparsing==3.2.1 python-dateutil==2.9.0.post0 pytorch-lightning==2.5.0.post0 pywin32==308 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32' pyyaml==6.0.2 pyzmq==26.2.1 ruff==0.9.7 setuptools==75.8.0 six==1.17.0 stack-data==0.6.3 sympy==1.13.1 tensorboard==2.19.0 tensorboard-data-server==0.7.2 tomli==2.2.1 ; python_full_version < '3.11' torch==2.6.0 torchmetrics==1.6.1 torchsummary==1.5.1 tornado==6.4.2 tqdm==4.67.1 traitlets==5.14.3 triton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux' 
typing-extensions==4.12.2 virtualenv==20.29.2 wcwidth==0.2.13 werkzeug==3.1.3 yarl==1.18.3 ================================================ FILE: requirements.txt ================================================ # This file was autogenerated by uv via the following command: # uv export --format=requirements-txt --no-dev --no-hashes --no-sources --output-file=requirements.txt absl-py==2.1.0 aiohappyeyeballs==2.4.6 aiohttp==3.11.12 aiosignal==1.3.2 async-timeout==5.0.1 ; python_full_version < '3.11' attrs==25.1.0 colorama==0.4.6 ; sys_platform == 'win32' contourpy==1.3.1 cycler==0.12.1 filelock==3.17.0 fonttools==4.56.0 frozenlist==1.5.0 fsspec==2025.2.0 grpcio==1.70.0 h5py==3.13.0 idna==3.10 jinja2==3.1.5 kiwisolver==1.4.8 lightning==2.5.0.post0 lightning-utilities==0.12.0 markdown==3.7 markupsafe==3.0.2 matplotlib==3.10.0 mpmath==1.3.0 multidict==6.1.0 networkx==3.4.2 numpy==2.2.3 nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux' nvidia-nvtx-cu12==12.4.127 ; 
platform_machine == 'x86_64' and sys_platform == 'linux' packaging==24.2 pillow==11.1.0 propcache==0.3.0 protobuf==5.29.3 pyparsing==3.2.1 python-dateutil==2.9.0.post0 pytorch-lightning==2.5.0.post0 pyyaml==6.0.2 setuptools==75.8.0 six==1.17.0 sympy==1.13.1 tensorboard==2.19.0 tensorboard-data-server==0.7.2 torch==2.6.0 torchmetrics==1.6.1 torchsummary==1.5.1 tqdm==4.67.1 triton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux' typing-extensions==4.12.2 werkzeug==3.1.3 yarl==1.18.3 ================================================ FILE: root.py ================================================ from pathlib import Path ROOT_DIR = Path(__file__).parent ================================================ FILE: train_SmaAtUNet.py ================================================ from typing import Optional from models.SmaAt_UNet import SmaAt_UNet import torch from torch.utils.data import DataLoader from torch import optim from torch import nn from torchvision import transforms from root import ROOT_DIR from utils import dataset_VOC import time from tqdm import tqdm from metric import iou import os def get_lr(optimizer): for param_group in optimizer.param_groups: return param_group["lr"] def fit( epochs, model, loss_func, opt, train_dl, valid_dl, dev=None, save_every: Optional[int] = None, tensorboard: bool = False, earlystopping=None, lr_scheduler=None, ): writer = None if tensorboard: from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter(comment=f"{model.__class__.__name__}") if dev is None: dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") start_time = time.time() best_mIoU = -1.0 earlystopping_counter = 0 for epoch in tqdm(range(epochs), desc="Epochs", leave=True): model.train() train_loss = 0.0 for _, (xb, yb) in enumerate(tqdm(train_dl, desc="Batches", leave=False)): # for i, (xb, yb) in enumerate(train_dl): loss = loss_func(model(xb.to(dev)), yb.to(dev)) opt.zero_grad() loss.backward() opt.step() train_loss += 
loss.item() # if i > 100: # break train_loss /= len(train_dl) # Reduce learning rate after epoch # scheduler.step() # Calc validation loss val_loss = 0.0 iou_metric = iou.IoU(21, normalized=False) model.eval() with torch.no_grad(): for xb, yb in tqdm(valid_dl, desc="Validation", leave=False): # for xb, yb in valid_dl: y_pred = model(xb.to(dev)) loss = loss_func(y_pred, yb.to(dev)) val_loss += loss.item() # Calculate mean IOU pred_class = torch.argmax(nn.functional.softmax(y_pred, dim=1), dim=1) iou_metric.add(pred_class, target=yb) iou_class, mean_iou = iou_metric.value() val_loss /= len(valid_dl) # Save the model with the best mean IoU if mean_iou > best_mIoU: os.makedirs(ROOT_DIR / "checkpoints", exist_ok=True) torch.save( { "model": model, "epoch": epoch, "state_dict": model.state_dict(), "optimizer_state_dict": opt.state_dict(), "val_loss": val_loss, "train_loss": train_loss, "mIOU": mean_iou, }, ROOT_DIR / "checkpoints" / f"best_mIoU_model_{model.__class__.__name__}.pt", ) best_mIoU = mean_iou earlystopping_counter = 0 else: earlystopping_counter += 1 if earlystopping is not None and earlystopping_counter >= earlystopping: print(f"Stopping early --> mean IoU has not improved over {earlystopping} epochs") break print( f"Epoch: {epoch:5d}, Time: {(time.time() - start_time) / 60:.3f} min, " f"Train_loss: {train_loss:2.10f}, Val_loss: {val_loss:2.10f},", f"mIOU: {mean_iou:.10f},", f"lr: {get_lr(opt)},", f"Early stopping counter: {earlystopping_counter}/{earlystopping}" if earlystopping is not None else "", ) if writer: # add to tensorboard writer.add_scalar("Loss/train", train_loss, epoch) writer.add_scalar("Loss/val", val_loss, epoch) writer.add_scalar("Metric/mIOU", mean_iou, epoch) writer.add_scalar("Parameters/learning_rate", get_lr(opt), epoch) if save_every is not None and epoch % save_every == 0: # save model torch.save( { "model": model, "epoch": epoch, "state_dict": model.state_dict(), "optimizer_state_dict": opt.state_dict(), # 'scheduler_state_dict':
scheduler.state_dict(), "val_loss": val_loss, "train_loss": train_loss, "mIOU": mean_iou, }, ROOT_DIR / "checkpoints" / f"model_{model.__class__.__name__}_epoch_{epoch}.pt", ) if lr_scheduler is not None: lr_scheduler.step(mean_iou) if __name__ == "__main__": dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") dataset_folder = ROOT_DIR / "data" / "VOCdevkit" batch_size = 8 learning_rate = 0.001 epochs = 200 earlystopping = 30 save_every = 1 # Load your dataset here transformations = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)]) voc_dataset_train = dataset_VOC.VOCSegmentation( root=dataset_folder, image_set="train", transformations=transformations, augmentations=True, ) voc_dataset_val = dataset_VOC.VOCSegmentation( root=dataset_folder, image_set="val", transformations=transformations, augmentations=False, ) train_dl = DataLoader( dataset=voc_dataset_train, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, ) valid_dl = DataLoader( dataset=voc_dataset_val, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, ) # Load SmaAt-UNet model = SmaAt_UNet(n_channels=3, n_classes=21) # Move model to device model.to(dev) # Define Optimizer and loss opt = optim.Adam(model.parameters(), lr=learning_rate) loss_func = nn.CrossEntropyLoss().to(dev) lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.1, patience=4) # Train network fit( epochs=epochs, model=model, loss_func=loss_func, opt=opt, train_dl=train_dl, valid_dl=valid_dl, dev=dev, save_every=save_every, tensorboard=True, earlystopping=earlystopping, lr_scheduler=lr_scheduler, ) ================================================ FILE: train_precip_lightning.py ================================================ from root import ROOT_DIR import lightning.pytorch as pl from lightning.pytorch.callbacks import ( ModelCheckpoint, LearningRateMonitor, EarlyStopping, ) from lightning.pytorch import loggers import argparse 
from models import unet_precip_regression_lightning as unet_regr from lightning.pytorch.tuner import Tuner def train_regression(hparams, find_batch_size_automatically: bool = False): if hparams.model == "UNetDSAttention": net = unet_regr.UNetDSAttention(hparams=hparams) elif hparams.model == "UNetAttention": net = unet_regr.UNetAttention(hparams=hparams) elif hparams.model == "UNet": net = unet_regr.UNet(hparams=hparams) elif hparams.model == "UNetDS": net = unet_regr.UNetDS(hparams=hparams) else: raise NotImplementedError(f"Model '{hparams.model}' not implemented") default_save_path = ROOT_DIR / "lightning" / "precip_regression" checkpoint_callback = ModelCheckpoint( dirpath=default_save_path / net.__class__.__name__, filename=net.__class__.__name__ + "_rain_threshold_50_{epoch}-{val_loss:.6f}", save_top_k=1, verbose=False, monitor="val_loss", mode="min", ) last_checkpoint_callback = ModelCheckpoint( dirpath=default_save_path / net.__class__.__name__, filename=net.__class__.__name__ + "_rain_threshold_50_{epoch}-{val_loss:.6f}_last", save_top_k=1, verbose=False, ) lr_monitor = LearningRateMonitor() tb_logger = loggers.TensorBoardLogger(save_dir=default_save_path, name=net.__class__.__name__) earlystopping_callback = EarlyStopping( monitor="val_loss", mode="min", patience=hparams.es_patience, ) trainer = pl.Trainer( accelerator="gpu", devices=1, fast_dev_run=hparams.fast_dev_run, max_epochs=hparams.epochs, default_root_dir=default_save_path, logger=tb_logger, callbacks=[checkpoint_callback, last_checkpoint_callback, earlystopping_callback, lr_monitor], val_check_interval=hparams.val_check_interval, ) if find_batch_size_automatically: tuner = Tuner(trainer) # Auto-scale batch size: double it until it no longer fits in memory, then binary-search the largest size that fits tuner.scale_batch_size(net, mode="binsearch") # This can be used to speed up training with newer GPUs: # https://lightning.ai/docs/pytorch/stable/advanced/speed.html#low-precision-matrix-multiplication # torch.set_float32_matmul_precision('medium')
trainer.fit(model=net, ckpt_path=hparams.resume_from_checkpoint) if __name__ == "__main__": parser = argparse.ArgumentParser() parser = unet_regr.PrecipRegressionBase.add_model_specific_args(parser) parser.add_argument( "--dataset_folder", default=ROOT_DIR / "data" / "precipitation" / "RAD_NL25_RAC_5min_train_test_2016-2019.h5", type=str, ) parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--learning_rate", type=float, default=0.001) parser.add_argument("--epochs", type=int, default=200) parser.add_argument("--fast_dev_run", action="store_true") parser.add_argument("--resume_from_checkpoint", type=str, default=None) parser.add_argument("--val_check_interval", type=float, default=None) args = parser.parse_args() # args.fast_dev_run = True args.n_channels = 12 # args.gpus = 1 args.model = "UNetDSAttention" args.lr_patience = 4 args.es_patience = 15 # args.val_check_interval = 0.25 args.kernels_per_layer = 2 args.use_oversampled_dataset = True args.dataset_folder = ( ROOT_DIR / "data" / "precipitation" / "train_test_2016-2019_input-length_12_img-ahead_6_rain-threshold_50.h5" ) # args.resume_from_checkpoint = f"lightning/precip_regression/{args.model}/UNetDSAttention.ckpt" # train_regression(args, find_batch_size_automatically=False) # All the models below will be trained for m in ["UNet", "UNetDS", "UNetAttention", "UNetDSAttention"]: args.model = m print(f"Start training model: {m}") train_regression(args, find_batch_size_automatically=False) ================================================ FILE: utils/__init__.py ================================================ ================================================ FILE: utils/data_loader_precip.py ================================================ import torch import numpy as np from torchvision import transforms from torch.utils.data.sampler import SubsetRandomSampler from utils import dataset_precip from root import ROOT_DIR # Taken from:
https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb def get_train_valid_loader( data_dir, batch_size, random_seed, num_input_images, num_output_images, augment, classification, valid_size=0.1, shuffle=True, num_workers=4, pin_memory=False, ): """ Utility function for loading and returning train and valid multi-process iterators over the precipitation dataset stored in an HDF5 file. If using CUDA, num_workers should be set to 1 and pin_memory to True. Params ------ - data_dir: path to the HDF5 dataset file. - batch_size: how many samples per batch to load. - random_seed: fix seed for reproducibility. - num_input_images: number of consecutive images used as model input. - num_output_images: number of images following the input sequence; only the last one is returned as the target. - augment: whether to apply a random horizontal flip. Only applied on the train split. - classification: whether to load the classification dataset (binned targets) instead of the regression dataset. - valid_size: fraction of the training set used for the validation set. Should be a float in the range [0, 1]. - shuffle: whether to shuffle the train/validation indices. - num_workers: number of subprocesses to use when loading the dataset. - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to True if using GPU. Returns ------- - train_loader: training set iterator. - valid_loader: validation set iterator. """ error_msg = "[!] valid_size should be in the range [0, 1]."
assert (valid_size >= 0) and (valid_size <= 1), error_msg # Since I am not dealing with RGB images I do not need this # normalize = transforms.Normalize( # mean=[0.4914, 0.4822, 0.4465], # std=[0.2023, 0.1994, 0.2010], # ) # define transforms valid_transform = None # valid_transform = transforms.Compose([ # transforms.ToTensor(), # normalize, # ]) if augment: # TODO flipping, rotating, sequence flipping (torch.flip seems very expensive) train_transform = transforms.Compose( [ transforms.RandomHorizontalFlip(), # transforms.ToTensor(), # normalize, ] ) else: train_transform = None # train_transform = transforms.Compose([ # transforms.ToTensor(), # normalize, # ]) if classification: # load the dataset train_dataset = dataset_precip.precipitation_maps_classification_h5( in_file=data_dir, num_input_images=num_input_images, img_to_predict=num_output_images, train=True, transform=train_transform, ) valid_dataset = dataset_precip.precipitation_maps_classification_h5( in_file=data_dir, num_input_images=num_input_images, img_to_predict=num_output_images, train=True, transform=valid_transform, ) else: # load the dataset train_dataset = dataset_precip.precipitation_maps_h5( in_file=data_dir, num_input_images=num_input_images, num_output_images=num_output_images, train=True, transform=train_transform, ) valid_dataset = dataset_precip.precipitation_maps_h5( in_file=data_dir, num_input_images=num_input_images, num_output_images=num_output_images, train=True, transform=valid_transform, ) num_train = len(train_dataset) indices = list(range(num_train)) split = int(np.floor(valid_size * num_train)) if shuffle: np.random.seed(random_seed) np.random.shuffle(indices) train_idx, valid_idx = indices[split:], indices[:split] train_sampler = SubsetRandomSampler(train_idx) valid_sampler = SubsetRandomSampler(valid_idx) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers, pin_memory=pin_memory, ) valid_loader = 
torch.utils.data.DataLoader( valid_dataset, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers, pin_memory=pin_memory, ) return train_loader, valid_loader def get_test_loader( data_dir, batch_size, num_input_images, num_output_images, classification, shuffle=False, num_workers=4, pin_memory=False, ): """ Utility function for loading and returning a multi-process test iterator over the precipitation dataset stored in an HDF5 file. If using CUDA, num_workers should be set to 1 and pin_memory to True. Params ------ - data_dir: path to the HDF5 dataset file. - batch_size: how many samples per batch to load. - num_input_images: number of consecutive images used as model input. - num_output_images: number of images following the input sequence; only the last one is returned as the target. - classification: whether to load the classification dataset (binned targets) instead of the regression dataset. - shuffle: whether to shuffle the dataset after every epoch. - num_workers: number of subprocesses to use when loading the dataset. - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to True if using GPU. Returns ------- - data_loader: test set iterator. """ # Since I am not dealing with RGB images I do not need this # normalize = transforms.Normalize( # mean=[0.485, 0.456, 0.406], # std=[0.229, 0.224, 0.225], # ) # define transform transform = None # transform = transforms.Compose([ # transforms.ToTensor(), # # normalize, # ]) if classification: dataset = dataset_precip.precipitation_maps_classification_h5( in_file=data_dir, num_input_images=num_input_images, img_to_predict=num_output_images, train=False, transform=transform, ) else: dataset = dataset_precip.precipitation_maps_h5( in_file=data_dir, num_input_images=num_input_images, num_output_images=num_output_images, train=False, transform=transform, ) data_loader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, ) return data_loader if __name__ == "__main__": folder = ROOT_DIR / "data" / "precipitation" data = "RAD_NL25_RAC_5min_train_test_2016-2019.h5" train_dl, valid_dl = get_train_valid_loader( folder / data, batch_size=8, random_seed=1337, num_input_images=12, num_output_images=6, classification=True, augment=False,
valid_size=0.1, shuffle=True, num_workers=4, pin_memory=False, ) for xb, yb in train_dl: print("xb.shape: ", xb.shape) print("yb.shape: ", yb.shape) break ================================================ FILE: utils/dataset_VOC.py ================================================ import random import torch from torch.utils.data import Dataset from PIL import Image from pathlib import Path from torchvision import transforms import torchvision.transforms.functional as TF import numpy as np import matplotlib.pyplot as plt def get_pascal_labels(): """Load the mapping that associates pascal classes with label colors Returns: np.ndarray with dimensions (21, 3) """ return np.asarray( [ [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128], ] ) def decode_segmap(label_mask, plot=False): """Decode segmentation class labels into a color image Args: label_mask (np.ndarray): an (M,N) array of integer values denoting the class label at each spatial location. plot (bool, optional): whether to show the resulting color image in a figure. Returns: (np.ndarray, optional): the resulting decoded color image. 
""" label_colours = get_pascal_labels() r = label_mask.copy() g = label_mask.copy() b = label_mask.copy() for ll in range(21): r[label_mask == ll] = label_colours[ll, 0] g[label_mask == ll] = label_colours[ll, 1] b[label_mask == ll] = label_colours[ll, 2] rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) rgb[:, :, 0] = r / 255.0 rgb[:, :, 1] = g / 255.0 rgb[:, :, 2] = b / 255.0 if plot: plt.imshow(rgb) plt.show() else: return rgb class VOCSegmentation(Dataset): CLASS_NAMES = [ "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "potted-plant", "sheep", "sofa", "train", "tv/monitor", ] def __init__(self, root: Path, image_set="train", transformations=None, augmentations=False): super().__init__() assert image_set in ["train", "val", "trainval"] voc_root = root / "VOC2012" image_dir = voc_root / "JPEGImages" mask_dir = voc_root / "SegmentationClass" splits_dir = voc_root / "ImageSets" / "Segmentation" split_f = splits_dir / (image_set.rstrip("\n") + ".txt") with open(split_f) as f: file_names = [x.strip() for x in f.readlines()] self.images = [image_dir / (x + ".jpg") for x in file_names] self.masks = [mask_dir / (x + ".png") for x in file_names] assert len(self.images) == len(self.masks) self.transformations = transformations self.augmentations = augmentations def __getitem__(self, index): img = Image.open(self.images[index]).convert("RGB") target = Image.open(self.masks[index]) if self.transformations is not None: img = self.transformations(img) target = self.transformations(target) # Apply augmentations to both input and target if self.augmentations: img, target = self.apply_augmentations(img, target) # Convert the RGB image to a tensor toTensorTransform = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ), ] ) img = toTensorTransform(img) # Convert target to long tensor target = 
torch.from_numpy(np.array(target)).long() target[target == 255] = 0 return img, target def __len__(self): return len(self.images) def apply_augmentations(self, img, target): # Horizontal flip if random.random() > 0.5: img = TF.hflip(img) target = TF.hflip(target) # Random Rotation (clockwise and counterclockwise) if random.random() > 0.5: degrees = 10 if random.random() > 0.5: degrees *= -1 img = TF.rotate(img, degrees) target = TF.rotate(target, degrees) # Brighten or darken image (only applied to input image) if random.random() > 0.5: brightness = 1.2 if random.random() > 0.5: brightness -= 0.4 img = TF.adjust_brightness(img, brightness) return img, target ================================================ FILE: utils/dataset_precip.py ================================================ from torch.utils.data import Dataset import h5py import numpy as np class precipitation_maps_h5(Dataset): def __init__(self, in_file, num_input_images, num_output_images, train=True, transform=None): super().__init__() self.file_name = in_file self.n_images, self.nx, self.ny = h5py.File(self.file_name, "r")["train" if train else "test"]["images"].shape self.num_input = num_input_images self.num_output = num_output_images self.sequence_length = num_input_images + num_output_images self.train = train # Dataset is all the images self.size_dataset = self.n_images - (num_input_images + num_output_images) # self.size_dataset = int(self.n_images/(num_input_images+num_output_images)) self.transform = transform self.dataset = None def __getitem__(self, index): # min_feature_range = 0.0 # max_feature_range = 1.0 # with h5py.File(self.file_name, 'r') as dataFile: # dataset = dataFile["train" if self.train else "test"]['images'][index:index+self.sequence_length] # load the file here (load as singleton) if self.dataset is None: self.dataset = h5py.File(self.file_name, "r", rdcc_nbytes=1024**3)["train" if self.train else "test"][ "images" ] imgs = np.array(self.dataset[index : index + 
self.sequence_length], dtype="float32") # add transforms if self.transform is not None: imgs = self.transform(imgs) input_img = imgs[: self.num_input] target_img = imgs[-1] return input_img, target_img def __len__(self): return self.size_dataset class precipitation_maps_oversampled_h5(Dataset): def __init__(self, in_file, num_input_images, num_output_images, train=True, transform=None): super().__init__() self.file_name = in_file self.samples, _, _, _ = h5py.File(self.file_name, "r")["train" if train else "test"]["images"].shape self.num_input = num_input_images self.num_output = num_output_images self.train = train # self.size_dataset = int(self.n_images/(num_input_images+num_output_images)) self.transform = transform self.dataset = None def __getitem__(self, index): # load the file here (load as singleton) if self.dataset is None: self.dataset = h5py.File(self.file_name, "r", rdcc_nbytes=1024**3)["train" if self.train else "test"][ "images" ] imgs = np.array(self.dataset[index], dtype="float32") # add transforms if self.transform is not None: imgs = self.transform(imgs) input_img = imgs[: self.num_input] target_img = imgs[-1] return input_img, target_img def __len__(self): return self.samples class precipitation_maps_classification_h5(Dataset): def __init__(self, in_file, num_input_images, img_to_predict, train=True, transform=None): super().__init__() self.file_name = in_file self.n_images, self.nx, self.ny = h5py.File(self.file_name, "r")["train" if train else "test"]["images"].shape self.num_input = num_input_images self.img_to_predict = img_to_predict self.sequence_length = num_input_images + img_to_predict self.bins = np.array([0.0, 0.5, 1, 2, 5, 10, 30]) self.train = train # Dataset is all the images self.size_dataset = self.n_images - (num_input_images + img_to_predict) # self.size_dataset = int(self.n_images/(num_input_images+num_output_images)) self.transform = transform self.dataset = None def __getitem__(self, index): # min_feature_range = 0.0 # 
max_feature_range = 1.0 # with h5py.File(self.file_name, 'r') as dataFile: # dataset = dataFile["train" if self.train else "test"]['images'][index:index+self.sequence_length] # load the file here (load as singleton) if self.dataset is None: self.dataset = h5py.File(self.file_name, "r", rdcc_nbytes=1024**3)["train" if self.train else "test"][ "images" ] imgs = np.array(self.dataset[index : index + self.sequence_length], dtype="float32") # add transforms if self.transform is not None: imgs = self.transform(imgs) input_img = imgs[: self.num_input] # put the img in buckets target_img = imgs[-1] # target_img is normalized by dividing through the highest value of the training set. We reverse this. # Then target_img is in mm/5min. The bins have the unit mm/hour. Therefore we multiply the img by 12 buckets = np.digitize(target_img * 47.83 * 12, self.bins, right=True) return input_img, buckets def __len__(self): return self.size_dataset ================================================ FILE: utils/formatting.py ================================================ import torch import numpy as np def make_metrics_str(metrics_dict): return " | ".join([f"{name}: {value.item() if isinstance(value, torch.Tensor) else value:.4f}" for name, value in metrics_dict.items() if not (isinstance(value, torch.Tensor) and torch.isnan(value)) and not (not isinstance(value, torch.Tensor) and np.isnan(value))]) ================================================ FILE: utils/model_classes.py ================================================ from models import unet_precip_regression_lightning as unet_regr import lightning.pytorch as pl def get_model_class(model_file) -> tuple[type[pl.LightningModule], str]: # This is for some nice plotting if "UNetAttention" in model_file: model_name = "UNet Attention" model = unet_regr.UNetAttention elif "UNetDSAttention4kpl" in model_file: model_name = "UNetDS Attention with 4kpl" model = unet_regr.UNetDSAttention elif "UNetDSAttention1kpl" in model_file: model_name = 
"UNetDS Attention with 1kpl" model = unet_regr.UNetDSAttention elif "UNetDSAttention4CBAMs" in model_file: model_name = "UNetDS Attention 4CBAMs" model = unet_regr.UNetDSAttention4CBAMs elif "UNetDSAttention" in model_file: model_name = "SmaAt-UNet" model = unet_regr.UNetDSAttention elif "UNetDS" in model_file: model_name = "UNetDS" model = unet_regr.UNetDS elif "UNet" in model_file: model_name = "UNet" model = unet_regr.UNet elif "PersistenceModel" in model_file: model_name = "PersistenceModel" model = unet_regr.PersistenceModel else: raise NotImplementedError("Model not found") return model, model_name