Repository: HansBambel/SmaAt-UNet
Branch: master
Commit: 3e48dcca3e88
Files: 32
Total size: 108.5 KB
Directory structure:
gitextract_3pkgi8mx/
├── .gitignore
├── .pre-commit-config.yaml
├── .python-version
├── README.md
├── calc_metrics_test_set.py
├── create_datasets.py
├── metric/
│ ├── __init__.py
│ ├── confusionmatrix.py
│ ├── iou.py
│ ├── metric.py
│ └── precipitation_metrics.py
├── models/
│ ├── SmaAt_UNet.py
│ ├── __init__.py
│ ├── layers.py
│ ├── regression_lightning.py
│ ├── unet_parts.py
│ ├── unet_parts_depthwise_separable.py
│ └── unet_precip_regression_lightning.py
├── plot_examples.ipynb
├── pyproject.toml
├── requirements-dev.txt
├── requirements-full.txt
├── requirements.txt
├── root.py
├── train_SmaAtUNet.py
├── train_precip_lightning.py
└── utils/
├── __init__.py
├── data_loader_precip.py
├── dataset_VOC.py
├── dataset_precip.py
├── formatting.py
└── model_classes.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.idea
.vscode
data
runs
lightning
checkpoints
plots
.ruff_cache
.mypy_cache
__pycache__
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.9.7'
hooks:
- id: ruff
args: ["--fix"]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
hooks:
- id: mypy
================================================
FILE: .python-version
================================================
3.10.11
================================================
FILE: README.md
================================================
# SmaAt-UNet
Code for the paper "SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture": [arXiv link](https://arxiv.org/abs/2007.04417), [Elsevier link](https://www.sciencedirect.com/science/article/pii/S0167865521000556?via%3Dihub).
Other relevant publications: [Weather Elements Nowcasting](https://sites.google.com/view/siamak-mehrkanoon/projects/weather-elements-nowcasting).

The proposed SmaAt-UNet can be found in the models folder under [SmaAt_UNet](models/SmaAt_UNet.py).
**>>>IMPORTANT<<<**
The original code from the paper can be found in this branch: https://github.com/HansBambel/SmaAt-UNet/tree/snapshot-paper
The current master branch has since been refactored and its packages upgraded. Since the exact package versions differ, the experiments may not be 100% reproducible.
If you have problems running the code, feel free to open an issue here on GitHub.
## Installing dependencies
This project uses [uv](https://docs.astral.sh/uv/) for dependency management. Therefore, installing the required dependencies is as easy as this:
```shell
uv sync --frozen
```
If you have NVIDIA GPU(s) that support CUDA and the command above did not install torch with CUDA support, you can install the relevant dependencies with the following command:
```shell
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
```
In any case, a [`requirements.txt`](requirements.txt) file is generated from `uv.lock`, which implies that `uv` must be installed in order to perform the export:
```shell
uv export --format=requirements-txt --no-dev --no-hashes --no-sources --output-file=requirements.txt
```
A `requirements-dev.txt` file is also generated for development dependencies:
```shell
uv export --format=requirements-txt --group dev --no-hashes --no-sources --output-file=requirements-dev.txt
```
A `requirements-full.txt` file is also generated for all dependencies (including development and additional dependencies):
```shell
uv export --format=requirements-txt --all-groups --no-hashes --no-sources --output-file=requirements-full.txt
```
Alternatively, users can use the existing [`requirements.txt`](requirements.txt), [`requirements-dev.txt`](requirements-dev.txt) and [`requirements-full.txt`](requirements-full.txt) files.
Basically, only the following requirements are needed for training:
```
tqdm
torch
lightning
matplotlib
tensorboard
torchsummary
h5py
numpy
```
Additional requirements are needed for [testing](#testing):
```
ipykernel
```
---
For the paper we used the [Lightning](https://github.com/Lightning-AI/lightning) module (PL), which simplifies the training process and allows easy addition of loggers and checkpoint creation.
In order to use PL we created the model [UNetDSAttention](models/unet_precip_regression_lightning.py), whose parent class inherits from `pl.LightningModule`. This model is the same as the pure PyTorch SmaAt-UNet implementation with the PL functions added.
### Training
An example [training script](train_SmaAtUNet.py) is given for a classification task (PascalVOC).
For training on the precipitation task we used the [train_precip_lightning.py](train_precip_lightning.py) file.
The training will place a checkpoint file for every model in the `default_save_path` `lightning/precip_regression`.
After training has finished, place the best models (probably those with the lowest validation loss) that you want to compare into the folder `checkpoints/comparison`.
### Testing
The script [calc_metrics_test_set.py](calc_metrics_test_set.py) will use all the models placed in `checkpoints/comparison` to calculate the MSE and other metrics such as Precision, Recall, Accuracy, F1, CSI, FAR, and HSS.
The results are saved in a subfolder of the model folder (as JSON, TXT, or CSV, depending on `--file-type`).
For now, the metrics of the persistence model are only calculated by the script [test_precip_lightning.py](test_precip_lightning.py). This script also uses all models in that folder and calculates their test losses in addition to those of the persistence model.
This will be addressed in [this issue](https://github.com/HansBambel/SmaAt-UNet/issues/28).
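For reference, the categorical scores above are derived from the pixel-wise confusion counts obtained after thresholding the rain rate. A minimal pure-Python sketch of the formulas (mirroring `metric/precipitation_metrics.py`; the counts here are made up and division-by-zero handling is omitted):

```python
def skill_scores(tp: int, fp: int, tn: int, fn: int) -> dict:
    """Categorical verification scores from pixel-wise confusion counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return {
        "precision": precision,
        "recall": recall,
        "accuracy": (tp + tn) / (tp + fp + tn + fn),
        "f1": 2 * precision * recall / (precision + recall),
        "csi": tp / (tp + fn + fp),  # Critical Success Index
        "far": fp / (tp + fp),       # False Alarm Ratio
        # Heidke Skill Score
        "hss": (tp * tn - fn * fp)
        / ((tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)),
    }


scores = skill_scores(tp=80, fp=20, tn=880, fn=20)
```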
### Plots
Example code for creating similar plots as in the paper can be found in [plot_examples.ipynb](plot_examples.ipynb).
You might need to make your kernel discoverable so that Jupyter can detect it. To do so, run the following command in your terminal:
```shell
uv run ipython kernel install --user --env VIRTUAL_ENV=$(pwd)/.venv --name=smaat_unet --display-name="SmaAt-UNet Kernel"
```
### Precipitation dataset
The dataset consists of precipitation maps in 5-minute intervals from 2016-2019, resulting in about 420,000 images.
The dataset is based on radar precipitation maps from the [Royal Netherlands Meteorological Institute (KNMI)](https://www.knmi.nl/over-het-knmi/about).
The original images were cropped as can be seen in the example below:

If you are interested in the dataset that we used, please write an e-mail to: k.trebing@alumni.maastrichtuniversity.nl and s.mehrkanoon@uu.nl
The 50% dataset is 4 GB in size and the 20% dataset is 16.5 GB. Use [create_datasets.py](create_datasets.py) to create the two datasets from the original dataset.
The dataset is already normalized using [Min-Max normalization](https://en.wikipedia.org/wiki/Feature_scaling#Rescaling_(min-max_normalization)).
In order to revert this you need to multiply the images by 47.83; the result is precipitation in mm/5min.
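That conversion can be sketched as follows (the helper names and sample values are made up; the extra factor of 12 used for hourly rates follows `metric/precipitation_metrics.py`, which converts mm/5min to mm/h before thresholding):

```python
import numpy as np

FACTOR = 47.83  # undoes the dataset's Min-Max normalization


def to_mm_per_5min(normalized: np.ndarray) -> np.ndarray:
    """Revert normalization: values become precipitation in mm/5min."""
    return normalized * FACTOR


def to_mm_per_hour(normalized: np.ndarray) -> np.ndarray:
    """As above, scaled by the 12 five-minute intervals per hour (mm/h)."""
    return normalized * FACTOR * 12


img = np.array([0.0, 0.01, 0.1])  # made-up normalized pixel values
mm5 = to_mm_per_5min(img)  # ≈ [0.0, 0.4783, 4.783]
mmh = to_mm_per_hour(img)  # ≈ [0.0, 5.7396, 57.396]
```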
### Citation
```
@article{TREBING2021,
title = {SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture},
journal = {Pattern Recognition Letters},
year = {2021},
issn = {0167-8655},
doi = {https://doi.org/10.1016/j.patrec.2021.01.036},
url = {https://www.sciencedirect.com/science/article/pii/S0167865521000556},
author = {Kevin Trebing and Tomasz Staǹczyk and Siamak Mehrkanoon},
keywords = {Domain adaptation, neural networks, kernel methods, coupling regularization},
abstract = {Weather forecasting is dominated by numerical weather prediction that tries to model accurately the physical properties of the atmosphere. A downside of numerical weather prediction is that it is lacking the ability for short-term forecasts using the latest available information. By using a data-driven neural network approach we show that it is possible to produce an accurate precipitation nowcast. To this end, we propose SmaAt-UNet, an efficient convolutional neural networks-based on the well known UNet architecture equipped with attention modules and depthwise-separable convolutions. We evaluate our approaches on a real-life datasets using precipitation maps from the region of the Netherlands and binary images of cloud coverage of France. The experimental results show that in terms of prediction performance, the proposed model is comparable to other examined models while only using a quarter of the trainable parameters.}
}
```
### Other relevant publications
Your project might also benefit from exploring our other relevant publications: [Weather Elements Nowcasting](https://sites.google.com/view/siamak-mehrkanoon/projects/weather-elements-nowcasting). Please consider citing them if you find them useful.
================================================
FILE: calc_metrics_test_set.py
================================================
import argparse
from argparse import Namespace
import json
import csv
import numpy as np
import os
from pprint import pprint
import torch
from tqdm import tqdm
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from metric.precipitation_metrics import PrecipitationMetrics
from models.unet_precip_regression_lightning import PersistenceModel
from root import ROOT_DIR
from utils import dataset_precip, model_classes
def parse_args():
parser = argparse.ArgumentParser(description="Calculate metrics for precipitation models")
parser.add_argument("--model-folder", type=str, default=None,
help="Path to model folder (default: ROOT_DIR/checkpoints/comparison)")
    parser.add_argument("--threshold", type=float, default=0.5, help="Precipitation threshold")
parser.add_argument("--normalized", action="store_false", dest="denormalize", default=True,
help="Calculate metrics for normalized data")
parser.add_argument("--load-metrics", action="store_true",
help="Load existing metrics instead of running experiments")
parser.add_argument("--file-type", type=str, default="txt", choices=["json", "txt", "csv"],
help="Output file format (json, txt, or csv)")
parser.add_argument("--save-dir", type=str, default=None,
help="Path to save the metrics")
parser.add_argument("--plot", action="store_true", help="Plot the metrics")
return parser.parse_args()
def convert_tensors_to_python(obj):
"""Convert PyTorch tensors to Python native types recursively."""
if isinstance(obj, torch.Tensor):
# Convert tensor to float/int, handling nan values
value = obj.item() if obj.numel() == 1 else obj.tolist()
return float('nan') if isinstance(value, float) and np.isnan(value) else value
elif isinstance(obj, dict):
return {key: convert_tensors_to_python(value) for key, value in obj.items()}
elif isinstance(obj, list):
return [convert_tensors_to_python(value) for value in obj]
return obj
def save_metrics(results, file_path, file_type="json"):
"""
Save metrics to a file in the specified format.
Args:
results (dict): Dictionary containing model metrics
file_path (Path): Path to save the file
file_type (str): File format - "json", "txt", or "csv"
"""
with open(file_path, "w") as f:
if file_type == "csv":
writer = csv.writer(f)
# Write header row
if results:
first_model = next(iter(results.values()))
writer.writerow(["Model"] + list(first_model.keys()))
# Write data rows
for model_name, metrics in results.items():
writer.writerow([model_name] + list(metrics.values()))
else:
# For json and txt formats, use json.dump
json.dump(results, f, indent=4)
def run_experiments(model_folder, data_file, threshold=0.5, denormalize=True):
"""
Run test experiments for all models in the model folder.
Args:
model_folder (Path): Path to the model folder
data_file (Path): Path to the data file
threshold (float): Threshold for the precipitation
Returns:
dict: Dictionary containing model metrics
"""
results = {}
# Load the dataset
dataset = dataset_precip.precipitation_maps_oversampled_h5(
in_file=data_file, num_input_images=12, num_output_images=6, train=False
)
# Create dataloader
test_dl = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False, pin_memory=True, persistent_workers=True)
# Create trainer with device specification
trainer = pl.Trainer(logger=False, enable_checkpointing=False)
# Find all models in the model folder
models = ["PersistenceModel"] + [f for f in os.listdir(model_folder) if f.endswith(".ckpt")]
if not models:
raise ValueError("No checkpoint files found in the model folder.")
for model_file in tqdm(models, desc="Models", leave=True):
model, model_name = model_classes.get_model_class(model_file)
print(f"Loading model: {model_name}")
if model_file == "PersistenceModel":
loaded_model = model(Namespace())
else:
loaded_model = model.load_from_checkpoint(os.path.join(model_folder, model_file))
# Override the test metrics, required for testing with normalized data
loaded_model.test_metrics = PrecipitationMetrics(threshold=threshold, denormalize=denormalize)
results[model_name] = trainer.test(model=loaded_model, dataloaders=[test_dl])[0]
return results
def plot_metrics(results, save_path=""):
"""
Plot all metrics from the results dictionary.
Args:
results (dict): Dictionary containing metrics for different models
save_path (str): Path to save the plots
"""
model_names = list(results.keys())
metrics = list(results[model_names[0]].keys())
for metric_name in metrics:
# Get values, replacing NaN with None to skip them in the plot
values = []
valid_names = []
for model_name in model_names:
val = results[model_name][metric_name]
# Skip NaN values
if not (isinstance(val, float) and np.isnan(val)):
values.append(val)
valid_names.append(model_name)
if len(valid_names) == 0:
continue
plt.figure(figsize=(10, 6))
plt.bar(valid_names, values)
plt.xticks(rotation=45)
plt.xlabel("Models")
plt.ylabel(f"{metric_name.upper()}")
plt.title(f"Comparison of different models - {metric_name.upper()}")
plt.tight_layout()
if save_path:
plt.savefig(f"{save_path}/{metric_name}.png")
plt.show()
def load_metrics_from_file(file_path, file_type):
"""
Load metrics from a file in the specified format.
Args:
file_path (Path): Path to the metrics file
file_type (str): File format - "json", "txt", or "csv"
Returns:
dict: Dictionary containing model metrics
"""
results = {}
with open(file_path) as f:
if file_type == "json" or file_type == "txt":
# Both json and txt files are saved in JSON format in this application
results = json.load(f)
elif file_type == "csv":
reader = csv.reader(f)
headers = next(reader) # Get header row
for row in reader:
if len(row) > 0:
model_name = row[0]
# Create dictionary of metrics for this model
model_metrics = {}
for i in range(1, len(headers)):
if i < len(row):
# Try to convert to float if possible
try:
value = float(row[i])
except (ValueError, TypeError):
value = row[i]
model_metrics[headers[i]] = value
results[model_name] = model_metrics
return results
def main():
# Parse command line arguments
args = parse_args()
# Variables from command line arguments
load_metrics = args.load_metrics
threshold = args.threshold
file_type = args.file_type
denormalize = args.denormalize
# Define paths
model_folder = ROOT_DIR / "checkpoints" / "comparison" if args.model_folder is None else args.model_folder
data_file = ROOT_DIR / "data" / "precipitation" / f"train_test_2016-2019_input-length_12_img-ahead_6_rain-threshold_{int(threshold*100)}.h5"
save_dir = model_folder / f"calc_metrics_test_set_results_{'denorm' if denormalize else 'norm'}" if args.save_dir is None else args.save_dir
os.makedirs(save_dir, exist_ok=True)
file_name = f"model_metrics_{threshold}_{'denorm' if denormalize else 'norm'}.{file_type}"
file_path = save_dir / file_name
# Load metrics if available
if load_metrics:
print(f"Loading metrics from {file_path}")
results = load_metrics_from_file(file_path, file_type)
else:
# Run experiments
print(f"Running experiments with {threshold} threshold and {'denormalized' if denormalize else 'normalized'} data")
results = run_experiments(model_folder, data_file, threshold=threshold, denormalize=denormalize)
# Convert tensors to Python native types
results = convert_tensors_to_python(results)
save_metrics(results, file_path, file_type)
# Display results
print(f"Metrics saved to {file_path}")
pprint(results)
# Plot metrics if requested
if args.plot:
plots_dir = save_dir / "plots"
os.makedirs(plots_dir, exist_ok=True)
plot_metrics(results, plots_dir)
print(f"Plots saved to {plots_dir}")
if __name__ == "__main__":
main()
================================================
FILE: create_datasets.py
================================================
import h5py
import numpy as np
from tqdm import tqdm
from root import ROOT_DIR
def create_dataset(input_length: int, image_ahead: int, rain_amount_thresh: float):
"""Create a dataset that has target images containing at least `rain_amount_thresh` (percent) of rain."""
precipitation_folder = ROOT_DIR / "data" / "precipitation"
with h5py.File(
precipitation_folder / "RAD_NL25_RAC_5min_train_test_2016-2019.h5",
"r",
) as orig_f:
train_images = orig_f["train"]["images"]
train_timestamps = orig_f["train"]["timestamps"]
test_images = orig_f["test"]["images"]
test_timestamps = orig_f["test"]["timestamps"]
print("Train shape", train_images.shape)
print("Test shape", test_images.shape)
imgSize = train_images.shape[1]
num_pixels = imgSize * imgSize
filename = (
precipitation_folder / f"train_test_2016-2019_input-length_{input_length}_img-"
f"ahead_{image_ahead}_rain-threshold_{int(rain_amount_thresh * 100)}.h5"
)
with h5py.File(filename, "w") as f:
train_set = f.create_group("train")
test_set = f.create_group("test")
train_image_dataset = train_set.create_dataset(
"images",
shape=(1, input_length + image_ahead, imgSize, imgSize),
maxshape=(None, input_length + image_ahead, imgSize, imgSize),
dtype="float32",
compression="gzip",
compression_opts=9,
)
train_timestamp_dataset = train_set.create_dataset(
"timestamps",
shape=(1, input_length + image_ahead, 1),
maxshape=(None, input_length + image_ahead, 1),
dtype=h5py.special_dtype(vlen=str),
compression="gzip",
compression_opts=9,
)
test_image_dataset = test_set.create_dataset(
"images",
shape=(1, input_length + image_ahead, imgSize, imgSize),
maxshape=(None, input_length + image_ahead, imgSize, imgSize),
dtype="float32",
compression="gzip",
compression_opts=9,
)
test_timestamp_dataset = test_set.create_dataset(
"timestamps",
shape=(1, input_length + image_ahead, 1),
maxshape=(None, input_length + image_ahead, 1),
dtype=h5py.special_dtype(vlen=str),
compression="gzip",
compression_opts=9,
)
origin = [[train_images, train_timestamps], [test_images, test_timestamps]]
datasets = [
[train_image_dataset, train_timestamp_dataset],
[test_image_dataset, test_timestamp_dataset],
]
for origin_id, (images, timestamps) in enumerate(origin):
image_dataset, timestamp_dataset = datasets[origin_id]
# Pre-calculate all valid indices that meet the rain threshold
sequence_length = input_length + image_ahead
valid_indices = []
            # Scan frames and collect the indices that meet the rain-pixel threshold
for i in tqdm(range(sequence_length, len(images)), desc="Finding valid indices"):
rain_pixels = np.sum(images[i] > 0)
if rain_pixels >= num_pixels * rain_amount_thresh:
valid_indices.append(i)
# Pre-allocate the final dataset size
total_sequences = len(valid_indices)
image_dataset.resize(total_sequences, axis=0)
timestamp_dataset.resize(total_sequences, axis=0)
# Batch process the sequences
for idx, i in enumerate(tqdm(valid_indices, desc="Creating sequences")):
imgs = images[i - sequence_length : i]
timestamps_img = timestamps[i - sequence_length : i]
image_dataset[idx] = imgs
timestamp_dataset[idx] = timestamps_img
if __name__ == "__main__":
    print("Creating dataset with at least 20% rain pixels in the target image")
    create_dataset(input_length=12, image_ahead=6, rain_amount_thresh=0.2)
    print("Creating dataset with at least 50% rain pixels in the target image")
    create_dataset(input_length=12, image_ahead=6, rain_amount_thresh=0.5)
================================================
FILE: metric/__init__.py
================================================
from .confusionmatrix import ConfusionMatrix
from .iou import IoU
from .metric import Metric
__all__ = ["ConfusionMatrix", "IoU", "Metric"]
================================================
FILE: metric/confusionmatrix.py
================================================
import numpy as np
import torch
from metric import metric
class ConfusionMatrix(metric.Metric):
    """Constructs a confusion matrix for multi-class classification problems.
Does not support multi-label, multi-class problems.
Keyword arguments:
- num_classes (int): number of classes in the classification problem.
- normalized (boolean, optional): Determines whether or not the confusion
matrix is normalized or not. Default: False.
Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
"""
def __init__(self, num_classes, normalized=False):
super().__init__()
self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32)
self.normalized = normalized
self.num_classes = num_classes
self.reset()
def reset(self):
self.conf.fill(0)
def add(self, predicted, target):
"""Computes the confusion matrix
The shape of the confusion matrix is K x K, where K is the number
of classes.
Keyword arguments:
- predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of
predicted scores obtained from the model for N examples and K classes,
or an N-tensor/array of integer values between 0 and K-1.
- target (Tensor or numpy.ndarray): Can be an N x K tensor/array of
ground-truth classes for N examples and K classes, or an N-tensor/array
of integer values between 0 and K-1.
"""
# If target and/or predicted are tensors, convert them to numpy arrays
if torch.is_tensor(predicted):
predicted = predicted.cpu().numpy()
if torch.is_tensor(target):
target = target.cpu().numpy()
assert predicted.shape[0] == target.shape[0], "number of targets and predicted outputs do not match"
if np.ndim(predicted) != 1:
assert (
predicted.shape[1] == self.num_classes
), "number of predictions does not match size of confusion matrix"
predicted = np.argmax(predicted, 1)
else:
assert (predicted.max() < self.num_classes) and (
predicted.min() >= 0
), "predicted values are not between 0 and k-1"
if np.ndim(target) != 1:
assert target.shape[1] == self.num_classes, "Onehot target does not match size of confusion matrix"
assert (target >= 0).all() and (target <= 1).all(), "in one-hot encoding, target values should be 0 or 1"
assert (target.sum(1) == 1).all(), "multi-label setting is not supported"
target = np.argmax(target, 1)
else:
assert (target.max() < self.num_classes) and (target.min() >= 0), "target values are not between 0 and k-1"
# hack for bincounting 2 arrays together
x = predicted + self.num_classes * target
bincount_2d = np.bincount(x.astype(np.int32), minlength=self.num_classes**2)
assert bincount_2d.size == self.num_classes**2
conf = bincount_2d.reshape((self.num_classes, self.num_classes))
self.conf += conf
def value(self):
"""
Returns:
        Confusion matrix of K rows and K columns, where rows correspond
        to ground-truth targets and columns correspond to predicted
        targets.
"""
if self.normalized:
conf = self.conf.astype(np.float32)
return conf / conf.sum(1).clip(min=1e-12)[:, None]
else:
return self.conf
================================================
FILE: metric/iou.py
================================================
import numpy as np
from metric import metric
from metric.confusionmatrix import ConfusionMatrix
# Taken from https://github.com/davidtvs/PyTorch-ENet/tree/master/metric
class IoU(metric.Metric):
"""Computes the intersection over union (IoU) per class and corresponding
mean (mIoU).
Intersection over union (IoU) is a common evaluation metric for semantic
segmentation. The predictions are first accumulated in a confusion matrix
and the IoU is computed from it as follows:
IoU = true_positive / (true_positive + false_positive + false_negative).
Keyword arguments:
- num_classes (int): number of classes in the classification problem
- normalized (boolean, optional): Determines whether or not the confusion
matrix is normalized or not. Default: False.
- ignore_index (int or iterable, optional): Index of the classes to ignore
when computing the IoU. Can be an int, or any iterable of ints.
"""
def __init__(self, num_classes, normalized=False, ignore_index=None):
super().__init__()
self.conf_metric = ConfusionMatrix(num_classes, normalized)
if ignore_index is None:
self.ignore_index = None
elif isinstance(ignore_index, int):
self.ignore_index = (ignore_index,)
else:
try:
self.ignore_index = tuple(ignore_index)
except TypeError as err:
raise ValueError("'ignore_index' must be an int or iterable") from err
def reset(self):
self.conf_metric.reset()
def add(self, predicted, target):
"""Adds the predicted and target pair to the IoU metric.
Keyword arguments:
- predicted (Tensor): Can be a (N, K, H, W) tensor of
predicted scores obtained from the model for N examples and K classes,
or (N, H, W) tensor of integer values between 0 and K-1.
- target (Tensor): Can be a (N, K, H, W) tensor of
target scores for N examples and K classes, or (N, H, W) tensor of
integer values between 0 and K-1.
"""
# Dimensions check
assert predicted.size(0) == target.size(0), "number of targets and predicted outputs do not match"
assert (
predicted.dim() == 3 or predicted.dim() == 4
), "predictions must be of dimension (N, H, W) or (N, K, H, W)"
assert target.dim() == 3 or target.dim() == 4, "targets must be of dimension (N, H, W) or (N, K, H, W)"
# If the tensor is in categorical format convert it to integer format
if predicted.dim() == 4:
_, predicted = predicted.max(1)
if target.dim() == 4:
_, target = target.max(1)
self.conf_metric.add(predicted.view(-1), target.view(-1))
def value(self):
"""Computes the IoU and mean IoU.
The mean computation ignores NaN elements of the IoU array.
Returns:
Tuple: (IoU, mIoU). The first output is the per class IoU,
for K classes it's numpy.ndarray with K elements. The second output,
is the mean IoU.
"""
conf_matrix = self.conf_metric.value()
if self.ignore_index is not None:
conf_matrix[:, self.ignore_index] = 0
conf_matrix[self.ignore_index, :] = 0
true_positive = np.diag(conf_matrix)
false_positive = np.sum(conf_matrix, 0) - true_positive
false_negative = np.sum(conf_matrix, 1) - true_positive
# Just in case we get a division by 0, ignore/hide the error
with np.errstate(divide="ignore", invalid="ignore"):
iou = true_positive / (true_positive + false_positive + false_negative)
return iou, np.nanmean(iou)
================================================
FILE: metric/metric.py
================================================
from typing import Protocol
class Metric(Protocol):
"""Base class for all metrics."""
def reset(self):
pass
def add(self, predicted, target):
pass
def value(self):
pass
================================================
FILE: metric/precipitation_metrics.py
================================================
import torch
from torchmetrics import Metric
import numpy as np
class PrecipitationMetrics(Metric):
"""
A custom metric for precipitation forecasting that computes both regression metrics (MSE)
and classification metrics (precision, recall, F1, etc.) based on a threshold.
"""
def __init__(self, threshold=0.5, denormalize=True, dist_sync_on_step=False):
"""
Args:
threshold: Threshold to convert continuous predictions into binary masks (in mm/h)
denormalize: If True, applies the factor to undo normalization
dist_sync_on_step: Synchronize metric state across processes at each step
"""
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.threshold = threshold
self.denormalize = denormalize
        self.factor = 47.83  # Factor to undo the Min-Max normalization (result in mm/5min)
# Add states for regression metrics
self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total_loss_denorm", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total_samples", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total_pixels", default=torch.tensor(0), dist_reduce_fx="sum")
# Add states for classification metrics
self.add_state("total_tp", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total_fp", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total_tn", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total_fn", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds, target):
"""
Update the metric states with new predictions and targets.
Args:
preds: Model predictions (normalized)
target: Ground truth (normalized)
"""
# Check for NaN values
if torch.isnan(preds).any() or torch.isnan(target).any():
print("Warning: NaN values detected in predictions or targets")
return
# Make sure preds and target have the same shape
if preds.shape != target.shape:
if len(preds.shape) < len(target.shape):
preds = preds.unsqueeze(0)
elif len(preds.shape) > len(target.shape):
preds = preds.squeeze()
# If squeezing made it too small, add dimension back
if len(preds.shape) < len(target.shape):
preds = preds.unsqueeze(0)
# Calculate MSE loss (normalized)
batch_size = target.size(0)
loss = torch.nn.functional.mse_loss(preds, target, reduction="sum") / batch_size
self.total_loss += loss
self.total_samples += batch_size # Use batch_size instead of 1
self.total_pixels += target.numel()
# Denormalize if needed
if self.denormalize:
preds_updated = preds * self.factor
target_updated = target * self.factor
# Calculate denormalized MSE loss
loss_denorm = torch.nn.functional.mse_loss(preds_updated, target_updated, reduction="sum") / batch_size
self.total_loss_denorm += loss_denorm
else:
preds_updated = preds
target_updated = target
# Convert to mm/h for classification metrics (multiply by 12 for 5min to hourly rate)
preds_hourly = preds_updated * 12
target_hourly = target_updated * 12
# Apply threshold to get binary masks
preds_mask = preds_hourly > self.threshold
target_mask = target_hourly > self.threshold
# Compute confusion matrix
confusion = target_mask.view(-1) * 2 + preds_mask.view(-1)
bincount = torch.bincount(confusion, minlength=4)
# Update confusion matrix states
self.total_tn += bincount[0]
self.total_fp += bincount[1]
self.total_fn += bincount[2]
self.total_tp += bincount[3]
def compute(self):
"""
Compute the final metrics from the accumulated states.
Returns:
A dictionary containing all metrics.
"""
# Compute regression metrics
mse = self.total_loss / self.total_samples
mse_denorm = self.total_loss_denorm / self.total_samples if self.denormalize else torch.tensor(float('nan'))
mse_pixel = self.total_loss_denorm / self.total_pixels if self.denormalize else torch.tensor(float('nan'))
# Compute classification metrics
precision = self.total_tp / (self.total_tp + self.total_fp) if (self.total_tp + self.total_fp) > 0 else torch.tensor(float('nan'))
recall = self.total_tp / (self.total_tp + self.total_fn) if (self.total_tp + self.total_fn) > 0 else torch.tensor(float('nan'))
accuracy = (self.total_tp + self.total_tn) / (self.total_tp + self.total_tn + self.total_fp + self.total_fn)
# F1 score
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(float('nan'))
# Critical Success Index (CSI) or Threat Score
csi = self.total_tp / (self.total_tp + self.total_fn + self.total_fp) if (self.total_tp + self.total_fn + self.total_fp) > 0 else torch.tensor(float('nan'))
# False Alarm Ratio (FAR)
far = self.total_fp / (self.total_tp + self.total_fp) if (self.total_tp + self.total_fp) > 0 else torch.tensor(float('nan'))
# Heidke Skill Score (HSS)
denom = ((self.total_tp + self.total_fn) * (self.total_fn + self.total_tn) +
(self.total_tp + self.total_fp) * (self.total_fp + self.total_tn))
hss = 2 * ((self.total_tp * self.total_tn) - (self.total_fn * self.total_fp)) / denom if denom > 0 else torch.tensor(float('nan'))
return {
"mse": mse,
"mse_denorm": mse_denorm,
"mse_pixel": mse_pixel,
"precision": precision,
"recall": recall,
"accuracy": accuracy,
"f1": f1,
"csi": csi,
"far": far,
"hss": hss
}
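As an editor's sketch (pure Python, hypothetical toy counts): the flattened encoding in `update` maps each pixel to `target * 2 + prediction`, so a single bincount yields TN/FP/FN/TP in that order, and the skill scores in `compute` then follow the textbook definitions (HSS carries a factor of 2):

```python
# Toy binary masks (flattened): 1 = rain above the threshold
target = [0, 0, 1, 1, 1, 0, 1, 0]
preds = [0, 1, 1, 0, 1, 0, 1, 1]

# Same encoding as update(): target*2 + pred -> 0=TN, 1=FP, 2=FN, 3=TP
codes = [t * 2 + p for t, p in zip(target, preds)]
tn, fp, fn, tp = (codes.count(k) for k in range(4))

# Skill scores as in compute()
csi = tp / (tp + fn + fp)  # Critical Success Index (Threat Score)
far = fp / (tp + fp)  # False Alarm Ratio
denom = (tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)
hss = 2 * (tp * tn - fn * fp) / denom  # Heidke Skill Score

print(tn, fp, fn, tp, csi, far, hss)
```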
================================================
FILE: models/SmaAt_UNet.py
================================================
from torch import nn
from models.unet_parts import OutConv
from models.unet_parts_depthwise_separable import DoubleConvDS, UpDS, DownDS
from models.layers import CBAM
class SmaAt_UNet(nn.Module):
def __init__(
self,
n_channels,
n_classes,
kernels_per_layer=2,
bilinear=True,
reduction_ratio=16,
):
super().__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)
self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)
self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)
self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)
self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
factor = 2 if self.bilinear else 1
self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)
self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)
self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)
self.outc = OutConv(64, self.n_classes)
def forward(self, x):
x1 = self.inc(x)
x1Att = self.cbam1(x1)
x2 = self.down1(x1)
x2Att = self.cbam2(x2)
x3 = self.down2(x2)
x3Att = self.cbam3(x3)
x4 = self.down3(x3)
x4Att = self.cbam4(x4)
x5 = self.down4(x4)
x5Att = self.cbam5(x5)
x = self.up1(x5Att, x4Att)
x = self.up2(x, x3Att)
x = self.up3(x, x2Att)
x = self.up4(x, x1Att)
logits = self.outc(x)
return logits
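A quick sanity check of the channel bookkeeping above (pure Python, no torch): with `bilinear=True` the bottleneck is halved (`factor = 2`), and each `UpDS` receives the concatenation of the upsampled map and the skip connection, which is why `up1` is constructed with 1024 input channels:

```python
# Encoder output channels for inc, down1..down4 with bilinear upsampling
bilinear = True
factor = 2 if bilinear else 1
encoder = [64, 128, 256, 512, 1024 // factor]

# up1 concatenates x5 (the bottleneck) with the x4 skip connection,
# so its in_channels is the sum of both channel counts
up1_in_channels = encoder[4] + encoder[3]
print(encoder, up1_in_channels)
```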
================================================
FILE: models/__init__.py
================================================
================================================
FILE: models/layers.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
# Taken from https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14
class DepthToSpace(nn.Module):
def __init__(self, block_size):
super().__init__()
self.bs = block_size
def forward(self, x):
N, C, H, W = x.size()
x = x.view(N, self.bs, self.bs, C // (self.bs**2), H, W) # (N, bs, bs, C//bs^2, H, W)
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
x = x.view(N, C // (self.bs**2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
return x
class SpaceToDepth(nn.Module):
# Expects the following shape: Batch, Channel, Height, Width
def __init__(self, block_size):
super().__init__()
self.bs = block_size
def forward(self, x):
N, C, H, W = x.size()
x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
x = x.view(N, C * (self.bs**2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
return x
class DepthwiseSeparableConv(nn.Module):
def __init__(self, in_channels, output_channels, kernel_size, padding=0, kernels_per_layer=1):
super().__init__()
# In Tensorflow DepthwiseConv2D has depth_multiplier instead of kernels_per_layer
self.depthwise = nn.Conv2d(
in_channels,
in_channels * kernels_per_layer,
kernel_size=kernel_size,
padding=padding,
groups=in_channels,
)
self.pointwise = nn.Conv2d(in_channels * kernels_per_layer, output_channels, kernel_size=1)
def forward(self, x):
x = self.depthwise(x)
x = self.pointwise(x)
return x
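To see why this factorisation is cheaper than a standard convolution, here is a back-of-the-envelope parameter count (biases ignored; the helper names are illustrative, not part of the codebase):

```python
# Parameters of a standard k x k convolution
def standard_conv_params(c_in, c_out, k=3):
    return c_in * c_out * k * k

# Parameters of the depthwise-separable factorisation above:
# a grouped k x k depthwise conv followed by a 1x1 pointwise conv
def ds_conv_params(c_in, c_out, k=3, kernels_per_layer=1):
    depthwise = c_in * kernels_per_layer * k * k  # groups=c_in
    pointwise = c_in * kernels_per_layer * c_out  # 1x1 conv
    return depthwise + pointwise

print(standard_conv_params(64, 64), ds_conv_params(64, 64))
```

For a 64-to-64-channel 3x3 layer the separable version needs roughly an eighth of the weights, which is where SmaAt-UNet's parameter savings come from.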
class DoubleDense(nn.Module):
def __init__(self, in_channels, hidden_neurons, output_channels):
super().__init__()
self.dense1 = nn.Linear(in_channels, out_features=hidden_neurons)
self.dense2 = nn.Linear(in_features=hidden_neurons, out_features=hidden_neurons // 2)
self.dense3 = nn.Linear(in_features=hidden_neurons // 2, out_features=output_channels)
def forward(self, x):
out = F.relu(self.dense1(x.view(x.size(0), -1)))
out = F.relu(self.dense2(out))
out = self.dense3(out)
return out
class DoubleDSConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_ds_conv = nn.Sequential(
DepthwiseSeparableConv(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
DepthwiseSeparableConv(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.double_ds_conv(x)
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
class ChannelAttention(nn.Module):
def __init__(self, input_channels, reduction_ratio=16):
super().__init__()
self.input_channels = input_channels
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
# https://github.com/luuuyi/CBAM.PyTorch/blob/master/model/resnet_cbam.py
# uses Convolutions instead of Linear
self.MLP = nn.Sequential(
Flatten(),
nn.Linear(input_channels, input_channels // reduction_ratio),
nn.ReLU(),
nn.Linear(input_channels // reduction_ratio, input_channels),
)
def forward(self, x):
# Take the input and apply average and max pooling
avg_values = self.avg_pool(x)
max_values = self.max_pool(x)
out = self.MLP(avg_values) + self.MLP(max_values)
scale = x * torch.sigmoid(out).unsqueeze(2).unsqueeze(3).expand_as(x)
return scale
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super().__init__()
assert kernel_size in (3, 7), "kernel size must be 3 or 7"
padding = 3 if kernel_size == 7 else 1
self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
self.bn = nn.BatchNorm2d(1)
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
out = torch.cat([avg_out, max_out], dim=1)
out = self.conv(out)
out = self.bn(out)
scale = x * torch.sigmoid(out)
return scale
class CBAM(nn.Module):
def __init__(self, input_channels, reduction_ratio=16, kernel_size=7):
super().__init__()
self.channel_att = ChannelAttention(input_channels, reduction_ratio=reduction_ratio)
self.spatial_att = SpatialAttention(kernel_size=kernel_size)
def forward(self, x):
out = self.channel_att(x)
out = self.spatial_att(out)
return out
================================================
FILE: models/regression_lightning.py
================================================
import lightning.pytorch as pl
import torch
from torch import nn, optim, Tensor
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from utils import dataset_precip
import argparse
import numpy as np
from metric.precipitation_metrics import PrecipitationMetrics
from utils.formatting import make_metrics_str
class UNetBase(pl.LightningModule):
@staticmethod
def add_model_specific_args(parent_parser):
parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument(
"--model",
type=str,
default="UNet",
choices=["UNet", "UNetDS", "UNetAttention", "UNetDSAttention", "PersistenceModel"],
)
parser.add_argument("--n_channels", type=int, default=12)
parser.add_argument("--n_classes", type=int, default=1)
parser.add_argument("--kernels_per_layer", type=int, default=1)
parser.add_argument("--bilinear", type=bool, default=True)
parser.add_argument("--reduction_ratio", type=int, default=16)
parser.add_argument("--lr_patience", type=int, default=5)
parser.add_argument("--threshold", type=float, default=0.5)
return parser
def __init__(self, hparams):
super().__init__()
self.save_hyperparameters(hparams)
default_threshold = 0.5
self.train_metrics = PrecipitationMetrics(threshold=getattr(self.hparams, "threshold", default_threshold))
self.val_metrics = PrecipitationMetrics(threshold=getattr(self.hparams, "threshold", default_threshold))
self.test_metrics = PrecipitationMetrics(threshold=getattr(self.hparams, "threshold", default_threshold))
def forward(self, x):
raise NotImplementedError("Subclasses must implement forward")
def configure_optimizers(self):
opt = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
scheduler = {
"scheduler": optim.lr_scheduler.ReduceLROnPlateau(
opt, mode="min", factor=0.1, patience=self.hparams.lr_patience
),
"monitor": "val_loss", # Default: val_loss
}
return [opt], [scheduler]
def loss_func(self, y_pred, y_true):
# Ensure consistent shapes before computing loss
if y_pred.dim() > y_true.dim():
y_pred = y_pred.squeeze(1)
elif y_true.dim() > y_pred.dim():
y_pred = y_pred.unsqueeze(1)
# reduction="mean" is average of every pixel, but I want average of image
return nn.functional.mse_loss(y_pred, y_true, reduction="sum") / y_true.size(0)
def training_step(self, batch, batch_idx):
x, y = batch
y_pred = self(x)
loss = self.loss_func(y_pred, y)
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
# Update training metrics with detached tensors
self.train_metrics.update(y_pred.detach(), y.detach())
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_pred = self(x)
loss = self.loss_func(y_pred, y)
self.log("val_loss", loss, prog_bar=True)
# Update validation metrics with detached tensors
self.val_metrics.update(y_pred.detach(), y.detach())
def test_step(self, batch, batch_idx):
"""Calculate the loss (MSE per default) and other metrics on the test set normalized and denormalized."""
x, y = batch
y_pred = self(x)
# Update the test metrics
self.test_metrics.update(y_pred.detach(), y.detach())
def on_test_epoch_end(self):
"""Compute and log all metrics at the end of the test epoch."""
test_metrics_dict = self.test_metrics.compute()
# Print all metrics in one line
print(f"\n\nEpoch {self.current_epoch} - Test Metrics: {make_metrics_str(test_metrics_dict)}")
# Reset the metrics for the next test epoch
self.test_metrics.reset()
def on_train_epoch_end(self):
"""Compute and log all metrics at the end of the training epoch."""
train_metrics_dict = self.train_metrics.compute()
print(f"\n\nEpoch {self.current_epoch} - Train Metrics: {make_metrics_str(train_metrics_dict)}")
self.train_metrics.reset()
def on_validation_epoch_end(self):
"""Compute and log all metrics at the end of the validation epoch."""
val_metrics_dict = self.val_metrics.compute()
print(f"\n\nEpoch {self.current_epoch} - Validation Metrics: {make_metrics_str(val_metrics_dict)}")
self.val_metrics.reset()
class PrecipRegressionBase(UNetBase):
@staticmethod
def add_model_specific_args(parent_parser):
parent_parser = UNetBase.add_model_specific_args(parent_parser)
parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--num_input_images", type=int, default=12)
parser.add_argument("--num_output_images", type=int, default=6)
parser.add_argument("--valid_size", type=float, default=0.1)
parser.add_argument("--use_oversampled_dataset", type=bool, default=True)
parser.n_channels = parser.parse_args().num_input_images
parser.n_classes = 1
return parser
def __init__(self, hparams):
super().__init__(hparams=hparams)
self.train_dataset = None
self.valid_dataset = None
self.train_sampler = None
self.valid_sampler = None
def prepare_data(self):
# train_transform = transforms.Compose([
# transforms.RandomHorizontalFlip()]
# )
train_transform = None
valid_transform = None
precip_dataset = (
dataset_precip.precipitation_maps_oversampled_h5
if self.hparams.use_oversampled_dataset
else dataset_precip.precipitation_maps_h5
)
self.train_dataset = precip_dataset(
in_file=self.hparams.dataset_folder,
num_input_images=self.hparams.num_input_images,
num_output_images=self.hparams.num_output_images,
train=True,
transform=train_transform,
)
self.valid_dataset = precip_dataset(
in_file=self.hparams.dataset_folder,
num_input_images=self.hparams.num_input_images,
num_output_images=self.hparams.num_output_images,
train=True,
transform=valid_transform,
)
num_train = len(self.train_dataset)
indices = list(range(num_train))
split = int(np.floor(self.hparams.valid_size * num_train))
np.random.shuffle(indices)
train_idx, valid_idx = indices[split:], indices[:split]
self.train_sampler = SubsetRandomSampler(train_idx)
self.valid_sampler = SubsetRandomSampler(valid_idx)
def train_dataloader(self):
train_loader = DataLoader(
self.train_dataset,
batch_size=self.hparams.batch_size,
sampler=self.train_sampler,
pin_memory=True,
# The following can/should be tweaked depending on the number of CPU cores
num_workers=1,
persistent_workers=True,
)
return train_loader
def val_dataloader(self):
valid_loader = DataLoader(
self.valid_dataset,
batch_size=self.hparams.batch_size,
sampler=self.valid_sampler,
pin_memory=True,
# The following can/should be tweaked depending on the number of CPU cores
num_workers=1,
persistent_workers=True,
)
return valid_loader
class PersistenceModel(UNetBase):
def forward(self, x):
return x[:, -1:, :, :]
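The slice `x[:, -1:, :, :]` keeps the time/channel dimension while selecting only the newest input frame, so the output shape matches a one-step prediction. The same keep-dim slicing can be mimicked on plain lists (toy data, purely illustrative):

```python
# A toy batch of 2 sequences of 3 "frames"; the persistence baseline
# simply repeats the most recent frame as its forecast.
batch = [["t0", "t1", "t2"], ["t0", "t1", "t2"]]

# frames[-1:] (not frames[-1]) keeps the sequence dimension,
# mirroring x[:, -1:, :, :] in the model above
prediction = [frames[-1:] for frames in batch]
print(prediction)
```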
================================================
FILE: models/unet_parts.py
================================================
""" Parts of the U-Net model """
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
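The `diff // 2` / `diff - diff // 2` split in `Up.forward` distributes an odd size difference between the two sides of the pad. A small sketch (illustrative helper, mirroring `F.pad`'s `[left, right, top, bottom]` order for the last two dimensions):

```python
# Compute the padding Up.forward would apply to centre a (h1, w1)
# feature map x1 on a (h2, w2) skip connection x2.
def pad_amounts(h1, w1, h2, w2):
    diffY, diffX = h2 - h1, w2 - w1
    # [left, right, top, bottom]; an odd diff puts the extra pixel
    # on the right/bottom side
    return [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]

print(pad_amounts(4, 4, 7, 7))
```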
================================================
FILE: models/unet_parts_depthwise_separable.py
================================================
""" Parts of the U-Net model """
# Base model taken from: https://github.com/milesial/Pytorch-UNet
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.layers import DepthwiseSeparableConv
class DoubleConvDS(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None, kernels_per_layer=1):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
DepthwiseSeparableConv(
in_channels,
mid_channels,
kernel_size=3,
kernels_per_layer=kernels_per_layer,
padding=1,
),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
DepthwiseSeparableConv(
mid_channels,
out_channels,
kernel_size=3,
kernels_per_layer=kernels_per_layer,
padding=1,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.double_conv(x)
class DownDS(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels, kernels_per_layer=1):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConvDS(in_channels, out_channels, kernels_per_layer=kernels_per_layer),
)
def forward(self, x):
return self.maxpool_conv(x)
class UpDS(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True, kernels_per_layer=1):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.conv = DoubleConvDS(
in_channels,
out_channels,
in_channels // 2,
kernels_per_layer=kernels_per_layer,
)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConvDS(in_channels, out_channels, kernels_per_layer=kernels_per_layer)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
================================================
FILE: models/unet_precip_regression_lightning.py
================================================
from models.unet_parts import Down, DoubleConv, Up, OutConv
from models.unet_parts_depthwise_separable import DoubleConvDS, UpDS, DownDS
from models.layers import CBAM
from models.regression_lightning import PrecipRegressionBase, PersistenceModel
class UNet(PrecipRegressionBase):
def __init__(self, hparams):
super().__init__(hparams=hparams)
self.n_channels = self.hparams.n_channels
self.n_classes = self.hparams.n_classes
self.bilinear = self.hparams.bilinear
self.inc = DoubleConv(self.n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
factor = 2 if self.bilinear else 1
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512 // factor, self.bilinear)
self.up2 = Up(512, 256 // factor, self.bilinear)
self.up3 = Up(256, 128 // factor, self.bilinear)
self.up4 = Up(128, 64, self.bilinear)
self.outc = OutConv(64, self.n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
class UNetAttention(PrecipRegressionBase):
def __init__(self, hparams):
super().__init__(hparams=hparams)
self.n_channels = self.hparams.n_channels
self.n_classes = self.hparams.n_classes
self.bilinear = self.hparams.bilinear
reduction_ratio = self.hparams.reduction_ratio
self.inc = DoubleConv(self.n_channels, 64)
self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
self.down1 = Down(64, 128)
self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
self.down2 = Down(128, 256)
self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
self.down3 = Down(256, 512)
self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
factor = 2 if self.bilinear else 1
self.down4 = Down(512, 1024 // factor)
self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)
self.up1 = Up(1024, 512 // factor, self.bilinear)
self.up2 = Up(512, 256 // factor, self.bilinear)
self.up3 = Up(256, 128 // factor, self.bilinear)
self.up4 = Up(128, 64, self.bilinear)
self.outc = OutConv(64, self.n_classes)
def forward(self, x):
x1 = self.inc(x)
x1Att = self.cbam1(x1)
x2 = self.down1(x1)
x2Att = self.cbam2(x2)
x3 = self.down2(x2)
x3Att = self.cbam3(x3)
x4 = self.down3(x3)
x4Att = self.cbam4(x4)
x5 = self.down4(x4)
x5Att = self.cbam5(x5)
x = self.up1(x5Att, x4Att)
x = self.up2(x, x3Att)
x = self.up3(x, x2Att)
x = self.up4(x, x1Att)
logits = self.outc(x)
return logits
class UNetDS(PrecipRegressionBase):
def __init__(self, hparams):
super().__init__(hparams=hparams)
self.n_channels = self.hparams.n_channels
self.n_classes = self.hparams.n_classes
self.bilinear = self.hparams.bilinear
kernels_per_layer = self.hparams.kernels_per_layer
self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)
self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)
self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)
self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)
factor = 2 if self.bilinear else 1
self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)
self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)
self.outc = OutConv(64, self.n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
class UNetDSAttention(PrecipRegressionBase):
def __init__(self, hparams):
super().__init__(hparams=hparams)
self.n_channels = self.hparams.n_channels
self.n_classes = self.hparams.n_classes
self.bilinear = self.hparams.bilinear
reduction_ratio = self.hparams.reduction_ratio
kernels_per_layer = self.hparams.kernels_per_layer
self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)
self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)
self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)
self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)
self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
factor = 2 if self.bilinear else 1
self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)
self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)
self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)
self.outc = OutConv(64, self.n_classes)
def forward(self, x):
x1 = self.inc(x)
x1Att = self.cbam1(x1)
x2 = self.down1(x1)
x2Att = self.cbam2(x2)
x3 = self.down2(x2)
x3Att = self.cbam3(x3)
x4 = self.down3(x3)
x4Att = self.cbam4(x4)
x5 = self.down4(x4)
x5Att = self.cbam5(x5)
x = self.up1(x5Att, x4Att)
x = self.up2(x, x3Att)
x = self.up3(x, x2Att)
x = self.up4(x, x1Att)
logits = self.outc(x)
return logits
class UNetDSAttention4CBAMs(PrecipRegressionBase):
def __init__(self, hparams):
super().__init__(hparams=hparams)
self.n_channels = self.hparams.n_channels
self.n_classes = self.hparams.n_classes
self.bilinear = self.hparams.bilinear
reduction_ratio = self.hparams.reduction_ratio
kernels_per_layer = self.hparams.kernels_per_layer
self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)
self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)
self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)
self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)
self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
factor = 2 if self.bilinear else 1
self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)
self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)
self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)
self.outc = OutConv(64, self.n_classes)
def forward(self, x):
x1 = self.inc(x)
x1Att = self.cbam1(x1)
x2 = self.down1(x1)
x2Att = self.cbam2(x2)
x3 = self.down2(x2)
x3Att = self.cbam3(x3)
x4 = self.down3(x3)
x4Att = self.cbam4(x4)
x5 = self.down4(x4)
x = self.up1(x5, x4Att)
x = self.up2(x, x3Att)
x = self.up3(x, x2Att)
x = self.up4(x, x1Att)
logits = self.outc(x)
return logits
================================================
FILE: plot_examples.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"import models.unet_precip_regression_lightning as unet_regr\n",
"from utils import dataset_precip\n",
"\n",
"dataset = dataset_precip.precipitation_maps_oversampled_h5(\n",
" in_file=\"data/precipitation/train_test_2016-2019_input-length_12_img-ahead_6_rain-threshold_50.h5\",\n",
" num_input_images=12,\n",
" num_output_images=6, train=False)\n",
"# 222, 444, 777, 1337 is fine\n",
"x, y_true = dataset[777]\n",
"# add batch dimension\n",
"x = torch.tensor(x).unsqueeze(0)\n",
"y_true = torch.tensor(y_true).unsqueeze(0)\n",
"print(x.shape, y_true.shape)\n",
"\n",
"models_dir = \"checkpoints/comparison\"\n",
"precip_models = [m for m in os.listdir(models_dir) if \".ckpt\" in m]\n",
"print(precip_models)\n",
"\n",
"loss_func = nn.MSELoss(reduction=\"sum\")\n",
"print(\"Persistence\", loss_func(x.squeeze()[-1]*47.83, y_true.squeeze()*47.83).item())\n",
"\n",
"plot_images = dict()\n",
"for i, model_file in enumerate(precip_models):\n",
" # model_name = \"\".join(model_file.split(\"_\")[0])\n",
" if \"UNetAttention\" in model_file:\n",
" model_name = \"UNet with CBAM\"\n",
" model = unet_regr.UNetAttention\n",
" elif \"UNetDSAttention4CBAMs\" in model_file:\n",
" model_name = \"UNetDS Attention 4CBAMs\"\n",
" model = unet_regr.UNetDSAttention4CBAMs\n",
" elif \"UNetDSAttention\" in model_file:\n",
" model_name = \"SmaAt-UNet\"\n",
" model = unet_regr.UNetDSAttention\n",
" elif \"UNetDS\" in model_file:\n",
" model_name = \"UNet with DSC\"\n",
" model = unet_regr.UNetDS\n",
" elif \"UNet\" in model_file:\n",
" model_name = \"UNet\"\n",
" model = unet_regr.UNet\n",
" else:\n",
" raise NotImplementedError(f\"Model not found\")\n",
" model = model.load_from_checkpoint(f\"{models_dir}/{model_file}\")\n",
" # loaded = torch.load(f\"{models_dir}/{model_file}\")\n",
" # model = loaded[\"model\"]\n",
" # model.load_state_dict(loaded[\"state_dict\"])\n",
" # loss_func = nn.CrossEntropyLoss(weight=torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.]))\n",
"\n",
" model.to(\"cpu\")\n",
" model.eval()\n",
"\n",
" # get prediction of model\n",
" with torch.no_grad():\n",
" y_pred = model(x)\n",
" y_pred = y_pred.squeeze()\n",
" print(model_name, loss_func(y_pred*47.83, y_true.squeeze()*47.83).item())\n",
"\n",
" plot_images[model_name] = y_pred"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"#### Plotting ####\n",
"plt.figure(figsize=(10, 20))\n",
"plt.subplot(len(precip_models)+2, 1, 1)\n",
"plt.imshow(y_true.squeeze()*47.83)\n",
"plt.colorbar()\n",
"plt.title(\"Ground truth\")\n",
"plt.axis(\"off\")\n",
"\n",
"plt.subplot(len(precip_models)+2, 1, 2)\n",
"plt.imshow(x.squeeze()[-1]*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\n",
"plt.colorbar()\n",
"plt.title(\"Persistence\")\n",
"plt.axis(\"off\")\n",
"\n",
"for i, (model_name, y_pred) in enumerate(plot_images.items()):\n",
" print(model_name)\n",
" plt.subplot(len(precip_models)+2, 1, i+3)\n",
" plt.imshow(y_pred*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\n",
" plt.colorbar()\n",
" plt.title(model_name)\n",
" plt.axis(\"off\")\n",
"# plt.colorbar()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"fig, axesgrid = plt.subplots(nrows=2, ncols=(len(precip_models)+2)//2, figsize=(20, 10))\n",
"axes = axesgrid.flat\n",
"im = axes[0].imshow(y_true.squeeze()*47.83)\n",
"axes[0].set_title(\"Groundtruth\", fontsize=15)\n",
"axes[0].set_axis_off()\n",
"axes[1].imshow(x.squeeze()[-1]*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\n",
"axes[1].set_title(\"Persistence\", fontsize=15)\n",
"axes[1].set_axis_off()\n",
"\n",
"# for ax in axes.flat[2:]:\n",
"for i, (model_name, y_pred) in enumerate(plot_images.items()):\n",
" print(model_name)\n",
"# plt.subplot(len(precip_models)+2, 1, i+3)\n",
" axes[i+2].imshow(y_pred*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\n",
" axes[i+2].set_title(model_name, fontsize=15)\n",
" axes[i+2].set_axis_off()\n",
"# im = ax.imshow(np.random.random((16, 16)), cmap='viridis',\n",
"# vmin=0, vmax=1)\n",
"\n",
"output_dir = \"plots\"\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"fig.subplots_adjust(wspace=0.02)\n",
"cbar = fig.colorbar(im, ax=axesgrid.ravel().tolist(), shrink=1)\n",
"cbar.ax.set_ylabel('mm/5min', fontsize=15)\n",
"plt.savefig(os.path.join(output_dir, \"Precip_example.eps\"))\n",
"plt.savefig(os.path.join(output_dir, \"Precip_example.png\"), dpi=300)\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: pyproject.toml
================================================
[project]
name = "smaat-unet"
version = "0.1.0"
description = "Code for the paper `SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture`"
authors = [{ name = "Kevin Trebing", email = "Kevin.Trebing@gmx.net" }]
readme = "README.md"
requires-python = ">=3.10.11"
dependencies = [
"h5py>=3.13.0",
"lightning>=2.5.0.post0",
"matplotlib>=3.10.0",
"numpy>=2.2.3",
"tensorboard>=2.19.0",
"torch>=2.6.0",
"torchsummary>=1.5.1",
"tqdm>=4.67.1",
]
[dependency-groups]
dev = [
"mypy>=1.15.0",
"pre-commit>=4.1.0",
"ruff>=0.9.7",
]
scripts = [
"ipykernel>=6.29.5",
]
[[tool.uv.index]]
name = "pytorch-cu126"
url = "https://download.pytorch.org/whl/cu126"
explicit = true
[tool.uv.sources]
torch = [
{ index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
torchvision = [
{ index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
[tool.mypy]
python_version = "3.10"
ignore_missing_imports = true
[tool.ruff]
line-length = 120
[tool.ruff.lint]
select = [
"E", # pycodestyle Errors
"W", # pycodestyle Warnings
"F", # pyflakes
"UP", # pyupgrade
# "D", # pydocstyle
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"A", # flake8-builtins
"C90", # mccabe complexity
"DJ", # flake8-django
"PIE", # flake8-pie
# "SIM", # flake8-simplify
]
ignore = [
"B905", # length-checking in zip() only introduced with Python 3.10 (PEP618)
"UP007", # Python version 3.9 does not allow writing union types as X | Y
"UP038", # Python version 3.9 does not allow writing union types as X | Y
"D202", # No blank lines allowed after function docstring
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D205", # 1 blank line required between summary line and description
"D203", # One blank line before class docstring
"D212", # Multi-line docstring summary should start at the first line
"D213", # Multi-line docstring summary should start at the second line
"D400", # First line of docstring should end with a period
"D404", # First word of the docstring should not be "This"
"D415", # First line should end with a period, question mark, or exclamation point
"DJ001", # Avoid using `null=True` on string-based fields
]
exclude = [
".git",
".local",
".cache",
".venv",
"./venv",
".vscode",
"__pycache__",
"docs",
"build",
"dist",
"notebooks",
"migrations"
]
[tool.ruff.lint.mccabe]
# Flag errors (`C901`) whenever the complexity level exceeds 12.
max-complexity = 12
================================================
FILE: requirements-dev.txt
================================================
# This file was autogenerated by uv via the following command:
# uv export --format=requirements-txt --group dev --no-hashes --no-sources --output-file=requirements-dev.txt
absl-py==2.1.0
aiohappyeyeballs==2.4.6
aiohttp==3.11.12
aiosignal==1.3.2
async-timeout==5.0.1 ; python_full_version < '3.11'
attrs==25.1.0
cfgv==3.4.0
colorama==0.4.6 ; sys_platform == 'win32'
contourpy==1.3.1
cycler==0.12.1
distlib==0.3.9
filelock==3.17.0
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2025.2.0
grpcio==1.70.0
h5py==3.13.0
identify==2.6.7
idna==3.10
jinja2==3.1.5
kiwisolver==1.4.8
lightning==2.5.0.post0
lightning-utilities==0.12.0
markdown==3.7
markupsafe==3.0.2
matplotlib==3.10.0
mpmath==1.3.0
multidict==6.1.0
mypy==1.15.0
mypy-extensions==1.0.0
networkx==3.4.2
nodeenv==1.9.1
numpy==2.2.3
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
packaging==24.2
pillow==11.1.0
platformdirs==4.3.6
pre-commit==4.1.0
propcache==0.3.0
protobuf==5.29.3
pyparsing==3.2.1
python-dateutil==2.9.0.post0
pytorch-lightning==2.5.0.post0
pyyaml==6.0.2
ruff==0.9.7
setuptools==75.8.0
six==1.17.0
sympy==1.13.1
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tomli==2.2.1 ; python_full_version < '3.11'
torch==2.6.0
torchmetrics==1.6.1
torchsummary==1.5.1
tqdm==4.67.1
triton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
typing-extensions==4.12.2
virtualenv==20.29.2
werkzeug==3.1.3
yarl==1.18.3
================================================
FILE: requirements-full.txt
================================================
# This file was autogenerated by uv via the following command:
# uv export --format=requirements-txt --all-groups --no-hashes --no-sources --output-file=requirements-full.txt
absl-py==2.1.0
aiohappyeyeballs==2.4.6
aiohttp==3.11.12
aiosignal==1.3.2
appnope==0.1.4 ; sys_platform == 'darwin'
asttokens==3.0.0
async-timeout==5.0.1 ; python_full_version < '3.11'
attrs==25.1.0
cffi==1.17.1 ; implementation_name == 'pypy'
cfgv==3.4.0
colorama==0.4.6 ; sys_platform == 'win32'
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.12
decorator==5.1.1
distlib==0.3.9
exceptiongroup==1.2.2 ; python_full_version < '3.11'
executing==2.2.0
filelock==3.17.0
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2025.2.0
grpcio==1.70.0
h5py==3.13.0
identify==2.6.7
idna==3.10
ipykernel==6.29.5
ipython==8.32.0
jedi==0.19.2
jinja2==3.1.5
jupyter-client==8.6.3
jupyter-core==5.7.2
kiwisolver==1.4.8
lightning==2.5.0.post0
lightning-utilities==0.12.0
markdown==3.7
markupsafe==3.0.2
matplotlib==3.10.0
matplotlib-inline==0.1.7
mpmath==1.3.0
multidict==6.1.0
mypy==1.15.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
numpy==2.2.3
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
packaging==24.2
parso==0.8.4
pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
pillow==11.1.0
platformdirs==4.3.6
pre-commit==4.1.0
prompt-toolkit==3.0.50
propcache==0.3.0
protobuf==5.29.3
psutil==7.0.0
ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
pure-eval==0.2.3
pycparser==2.22 ; implementation_name == 'pypy'
pygments==2.19.1
pyparsing==3.2.1
python-dateutil==2.9.0.post0
pytorch-lightning==2.5.0.post0
pywin32==308 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32'
pyyaml==6.0.2
pyzmq==26.2.1
ruff==0.9.7
setuptools==75.8.0
six==1.17.0
stack-data==0.6.3
sympy==1.13.1
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tomli==2.2.1 ; python_full_version < '3.11'
torch==2.6.0
torchmetrics==1.6.1
torchsummary==1.5.1
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
triton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
typing-extensions==4.12.2
virtualenv==20.29.2
wcwidth==0.2.13
werkzeug==3.1.3
yarl==1.18.3
================================================
FILE: requirements.txt
================================================
# This file was autogenerated by uv via the following command:
# uv export --format=requirements-txt --no-dev --no-hashes --no-sources --output-file=requirements.txt
absl-py==2.1.0
aiohappyeyeballs==2.4.6
aiohttp==3.11.12
aiosignal==1.3.2
async-timeout==5.0.1 ; python_full_version < '3.11'
attrs==25.1.0
colorama==0.4.6 ; sys_platform == 'win32'
contourpy==1.3.1
cycler==0.12.1
filelock==3.17.0
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2025.2.0
grpcio==1.70.0
h5py==3.13.0
idna==3.10
jinja2==3.1.5
kiwisolver==1.4.8
lightning==2.5.0.post0
lightning-utilities==0.12.0
markdown==3.7
markupsafe==3.0.2
matplotlib==3.10.0
mpmath==1.3.0
multidict==6.1.0
networkx==3.4.2
numpy==2.2.3
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
packaging==24.2
pillow==11.1.0
propcache==0.3.0
protobuf==5.29.3
pyparsing==3.2.1
python-dateutil==2.9.0.post0
pytorch-lightning==2.5.0.post0
pyyaml==6.0.2
setuptools==75.8.0
six==1.17.0
sympy==1.13.1
tensorboard==2.19.0
tensorboard-data-server==0.7.2
torch==2.6.0
torchmetrics==1.6.1
torchsummary==1.5.1
tqdm==4.67.1
triton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
typing-extensions==4.12.2
werkzeug==3.1.3
yarl==1.18.3
================================================
FILE: root.py
================================================
from pathlib import Path
ROOT_DIR = Path(__file__).parent
================================================
FILE: train_SmaAtUNet.py
================================================
from typing import Optional
from models.SmaAt_UNet import SmaAt_UNet
import torch
from torch.utils.data import DataLoader
from torch import optim
from torch import nn
from torchvision import transforms
from root import ROOT_DIR
from utils import dataset_VOC
import time
from tqdm import tqdm
from metric import iou
import os
def get_lr(optimizer):
for param_group in optimizer.param_groups:
return param_group["lr"]
def fit(
epochs,
model,
loss_func,
opt,
train_dl,
valid_dl,
dev=None,
save_every: Optional[int] = None,
tensorboard: bool = False,
earlystopping=None,
lr_scheduler=None,
):
writer = None
if tensorboard:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(comment=f"{model.__class__.__name__}")
if dev is None:
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
start_time = time.time()
best_mIoU = -1.0
earlystopping_counter = 0
for epoch in tqdm(range(epochs), desc="Epochs", leave=True):
model.train()
train_loss = 0.0
for _, (xb, yb) in enumerate(tqdm(train_dl, desc="Batches", leave=False)):
# for i, (xb, yb) in enumerate(train_dl):
loss = loss_func(model(xb.to(dev)), yb.to(dev))
opt.zero_grad()
loss.backward()
opt.step()
train_loss += loss.item()
# if i > 100:
# break
train_loss /= len(train_dl)
# Reduce learning rate after epoch
# scheduler.step()
# Calc validation loss
val_loss = 0.0
iou_metric = iou.IoU(21, normalized=False)
model.eval()
with torch.no_grad():
for xb, yb in tqdm(valid_dl, desc="Validation", leave=False):
# for xb, yb in valid_dl:
y_pred = model(xb.to(dev))
loss = loss_func(y_pred, yb.to(dev))
val_loss += loss.item()
# Calculate mean IOU
pred_class = torch.argmax(nn.functional.softmax(y_pred, dim=1), dim=1)
iou_metric.add(pred_class, target=yb)
iou_class, mean_iou = iou_metric.value()
val_loss /= len(valid_dl)
# Save the model with the best mean IoU
if mean_iou > best_mIoU:
os.makedirs("checkpoints", exist_ok=True)
torch.save(
{
"model": model,
"epoch": epoch,
"state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"val_loss": val_loss,
"train_loss": train_loss,
"mIOU": mean_iou,
},
ROOT_DIR / "checkpoints" / f"best_mIoU_model_{model.__class__.__name__}.pt",
)
best_mIoU = mean_iou
earlystopping_counter = 0
else:
earlystopping_counter += 1
if earlystopping is not None and earlystopping_counter >= earlystopping:
print(f"Stopping early --> mean IoU has not decreased over {earlystopping} epochs")
break
print(
f"Epoch: {epoch:5d}, Time: {(time.time() - start_time) / 60:.3f} min,"
f"Train_loss: {train_loss:2.10f}, Val_loss: {val_loss:2.10f},",
f"mIOU: {mean_iou:.10f},",
f"lr: {get_lr(opt)},",
f"Early stopping counter: {earlystopping_counter}/{earlystopping}" if earlystopping is not None else "",
)
if writer:
# add to tensorboard
writer.add_scalar("Loss/train", train_loss, epoch)
writer.add_scalar("Loss/val", val_loss, epoch)
writer.add_scalar("Metric/mIOU", mean_iou, epoch)
writer.add_scalar("Parameters/learning_rate", get_lr(opt), epoch)
if save_every is not None and epoch % save_every == 0:
# save model
torch.save(
{
"model": model,
"epoch": epoch,
"state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
# 'scheduler_state_dict': scheduler.state_dict(),
"val_loss": val_loss,
"train_loss": train_loss,
"mIOU": mean_iou,
},
ROOT_DIR / "checkpoints" / f"model_{model.__class__.__name__}_epoch_{epoch}.pt",
)
if lr_scheduler is not None:
lr_scheduler.step(mean_iou)
if __name__ == "__main__":
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dataset_folder = ROOT_DIR / "data" / "VOCdevkit"
batch_size = 8
learning_rate = 0.001
epochs = 200
earlystopping = 30
save_every = 1
# Load your dataset here
transformations = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])
voc_dataset_train = dataset_VOC.VOCSegmentation(
root=dataset_folder,
image_set="train",
transformations=transformations,
augmentations=True,
)
voc_dataset_val = dataset_VOC.VOCSegmentation(
root=dataset_folder,
image_set="val",
transformations=transformations,
augmentations=False,
)
train_dl = DataLoader(
dataset=voc_dataset_train,
batch_size=batch_size,
shuffle=True,
num_workers=0,
pin_memory=True,
)
valid_dl = DataLoader(
dataset=voc_dataset_val,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
)
# Load SmaAt-UNet
model = SmaAt_UNet(n_channels=3, n_classes=21)
# Move model to device
model.to(dev)
# Define Optimizer and loss
opt = optim.Adam(model.parameters(), lr=learning_rate)
loss_func = nn.CrossEntropyLoss().to(dev)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.1, patience=4)
# Train network
fit(
epochs=epochs,
model=model,
loss_func=loss_func,
opt=opt,
train_dl=train_dl,
valid_dl=valid_dl,
dev=dev,
save_every=save_every,
tensorboard=True,
earlystopping=earlystopping,
lr_scheduler=lr_scheduler,
)
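The mIoU-based early stopping inside `fit` can be sketched in isolation; the per-epoch mIoU values below are made up for illustration:

```python
# Minimal sketch of the early-stopping logic in fit(): the counter resets
# whenever mean IoU improves, and training stops once it reaches `patience`.
best_miou, counter, patience = -1.0, 0, 3
stopped_at = None
miou_per_epoch = [0.20, 0.30, 0.29, 0.31, 0.30, 0.28, 0.27, 0.26]  # made-up values
for epoch, miou in enumerate(miou_per_epoch):
    if miou > best_miou:
        best_miou, counter = miou, 0  # in fit(), the best checkpoint is saved here
    else:
        counter += 1
        if counter >= patience:
            stopped_at = epoch
            break
print(stopped_at, best_miou)
```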
================================================
FILE: train_precip_lightning.py
================================================
from root import ROOT_DIR
import lightning.pytorch as pl
from lightning.pytorch.callbacks import (
ModelCheckpoint,
LearningRateMonitor,
EarlyStopping,
)
from lightning.pytorch import loggers
import argparse
from models import unet_precip_regression_lightning as unet_regr
from lightning.pytorch.tuner import Tuner
def train_regression(hparams, find_batch_size_automatically: bool = False):
if hparams.model == "UNetDSAttention":
net = unet_regr.UNetDSAttention(hparams=hparams)
elif hparams.model == "UNetAttention":
net = unet_regr.UNetAttention(hparams=hparams)
elif hparams.model == "UNet":
net = unet_regr.UNet(hparams=hparams)
elif hparams.model == "UNetDS":
net = unet_regr.UNetDS(hparams=hparams)
else:
raise NotImplementedError(f"Model '{hparams.model}' not implemented")
default_save_path = ROOT_DIR / "lightning" / "precip_regression"
checkpoint_callback = ModelCheckpoint(
dirpath=default_save_path / net.__class__.__name__,
filename=net.__class__.__name__ + "_rain_threshold_50_{epoch}-{val_loss:.6f}",
save_top_k=1,
verbose=False,
monitor="val_loss",
mode="min",
)
last_checkpoint_callback = ModelCheckpoint(
dirpath=default_save_path / net.__class__.__name__,
filename=net.__class__.__name__ + "_rain_threshold_50_{epoch}-{val_loss:.6f}_last",
save_top_k=1,
verbose=False,
)
lr_monitor = LearningRateMonitor()
tb_logger = loggers.TensorBoardLogger(save_dir=default_save_path, name=net.__class__.__name__)
earlystopping_callback = EarlyStopping(
monitor="val_loss",
mode="min",
patience=hparams.es_patience,
)
trainer = pl.Trainer(
accelerator="gpu",
devices=1,
fast_dev_run=hparams.fast_dev_run,
max_epochs=hparams.epochs,
default_root_dir=default_save_path,
logger=tb_logger,
callbacks=[checkpoint_callback, last_checkpoint_callback, earlystopping_callback, lr_monitor],
val_check_interval=hparams.val_check_interval,
)
if find_batch_size_automatically:
tuner = Tuner(trainer)
# Auto-scale batch size via binary search ("binsearch" mode)
tuner.scale_batch_size(net, mode="binsearch")
# This can be used to speed up training with newer GPUs:
# https://lightning.ai/docs/pytorch/stable/advanced/speed.html#low-precision-matrix-multiplication
# torch.set_float32_matmul_precision('medium')
trainer.fit(model=net, ckpt_path=hparams.resume_from_checkpoint)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = unet_regr.PrecipRegressionBase.add_model_specific_args(parser)
parser.add_argument(
"--dataset_folder",
default=ROOT_DIR / "data" / "precipitation" / "RAD_NL25_RAC_5min_train_test_2016-2019.h5",
type=str,
)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--learning_rate", type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=200)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
parser.add_argument("--val_check_interval", type=float, default=None)
args = parser.parse_args()
# args.fast_dev_run = True
args.n_channels = 12
# args.gpus = 1
args.model = "UNetDSAttention"
args.lr_patience = 4
args.es_patience = 15
# args.val_check_interval = 0.25
args.kernels_per_layer = 2
args.use_oversampled_dataset = True
args.dataset_folder = (
ROOT_DIR / "data" / "precipitation" / "train_test_2016-2019_input-length_12_img-ahead_6_rain-threshold_50.h5"
)
# args.resume_from_checkpoint = f"lightning/precip_regression/{args.model}/UNetDSAttention.ckpt"
# train_regression(args, find_batch_size_automatically=False)
# All the models below will be trained
for m in ["UNet", "UNetDS", "UNetAttention", "UNetDSAttention"]:
args.model = m
print(f"Start training model: {m}")
train_regression(args, find_batch_size_automatically=False)
================================================
FILE: utils/__init__.py
================================================
================================================
FILE: utils/data_loader_precip.py
================================================
import torch
import numpy as np
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from utils import dataset_precip
from root import ROOT_DIR
# Taken from: https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb
def get_train_valid_loader(
data_dir,
batch_size,
random_seed,
num_input_images,
num_output_images,
augment,
classification,
valid_size=0.1,
shuffle=True,
num_workers=4,
pin_memory=False,
):
"""
Utility function for loading and returning train and valid
multi-process iterators over the precipitation-map dataset.
If using CUDA, num_workers should be set to 1 and pin_memory to True.
Params
------
- data_dir: path directory to the dataset.
- batch_size: how many samples per batch to load.
- augment: whether to apply the data augmentation scheme
mentioned in the paper. Only applied on the train split.
- random_seed: fix seed for reproducibility.
- valid_size: percentage split of the training set used for
the validation set. Should be a float in the range [0, 1].
- shuffle: whether to shuffle the train/validation indices.
- num_workers: number of subprocesses to use when loading the dataset.
- pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
True if using GPU.
Returns
-------
- train_loader: training set iterator.
- valid_loader: validation set iterator.
"""
error_msg = "[!] valid_size should be in the range [0, 1]."
assert (valid_size >= 0) and (valid_size <= 1), error_msg
# Since I am not dealing with RGB images I do not need this
# normalize = transforms.Normalize(
# mean=[0.4914, 0.4822, 0.4465],
# std=[0.2023, 0.1994, 0.2010],
# )
# define transforms
valid_transform = None
# valid_transform = transforms.Compose([
# transforms.ToTensor(),
# normalize,
# ])
if augment:
# TODO flipping, rotating, sequence flipping (torch.flip seems very expensive)
train_transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
# transforms.ToTensor(),
# normalize,
]
)
else:
train_transform = None
# train_transform = transforms.Compose([
# transforms.ToTensor(),
# normalize,
# ])
if classification:
# load the dataset
train_dataset = dataset_precip.precipitation_maps_classification_h5(
in_file=data_dir,
num_input_images=num_input_images,
img_to_predict=num_output_images,
train=True,
transform=train_transform,
)
valid_dataset = dataset_precip.precipitation_maps_classification_h5(
in_file=data_dir,
num_input_images=num_input_images,
img_to_predict=num_output_images,
train=True,
transform=valid_transform,
)
else:
# load the dataset
train_dataset = dataset_precip.precipitation_maps_h5(
in_file=data_dir,
num_input_images=num_input_images,
num_output_images=num_output_images,
train=True,
transform=train_transform,
)
valid_dataset = dataset_precip.precipitation_maps_h5(
in_file=data_dir,
num_input_images=num_input_images,
num_output_images=num_output_images,
train=True,
transform=valid_transform,
)
num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))
if shuffle:
np.random.seed(random_seed)
np.random.shuffle(indices)
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
sampler=train_sampler,
num_workers=num_workers,
pin_memory=pin_memory,
)
valid_loader = torch.utils.data.DataLoader(
valid_dataset,
batch_size=batch_size,
sampler=valid_sampler,
num_workers=num_workers,
pin_memory=pin_memory,
)
return train_loader, valid_loader
def get_test_loader(
data_dir,
batch_size,
num_input_images,
num_output_images,
classification,
shuffle=False,
num_workers=4,
pin_memory=False,
):
"""
Utility function for loading and returning a multi-process
test iterator over the precipitation-map dataset.
If using CUDA, num_workers should be set to 1 and pin_memory to True.
Params
------
- data_dir: path directory to the dataset.
- batch_size: how many samples per batch to load.
- shuffle: whether to shuffle the dataset after every epoch.
- num_workers: number of subprocesses to use when loading the dataset.
- pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
True if using GPU.
Returns
-------
- data_loader: test set iterator.
"""
# Since I am not dealing with RGB images I do not need this
# normalize = transforms.Normalize(
# mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225],
# )
# define transform
transform = None
# transform = transforms.Compose([
# transforms.ToTensor(),
# # normalize,
# ])
if classification:
dataset = dataset_precip.precipitation_maps_classification_h5(
in_file=data_dir,
num_input_images=num_input_images,
img_to_predict=num_output_images,
train=False,
transform=transform,
)
else:
dataset = dataset_precip.precipitation_maps_h5(
in_file=data_dir,
num_input_images=num_input_images,
num_output_images=num_output_images,
train=False,
transform=transform,
)
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
pin_memory=pin_memory,
)
return data_loader
if __name__ == "__main__":
folder = ROOT_DIR / "data" / "precipitation"
data = "RAD_NL25_RAC_5min_train_test_2016-2019.h5"
train_dl, valid_dl = get_train_valid_loader(
folder / data,
batch_size=8,
random_seed=1337,
num_input_images=12,
num_output_images=6,
classification=True,
augment=False,
valid_size=0.1,
shuffle=True,
num_workers=4,
pin_memory=False,
)
for xb, yb in train_dl:
print("xb.shape: ", xb.shape)
print("yb.shape: ", yb.shape)
break
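The index-splitting logic inside `get_train_valid_loader` can be reproduced standalone; the dataset size of 100 below is a made-up illustrative value:

```python
import numpy as np

# Reproduce the train/valid index split from get_train_valid_loader
num_train, valid_size, random_seed = 100, 0.1, 1337  # illustrative values
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))  # first `split` shuffled indices become validation
np.random.seed(random_seed)
np.random.shuffle(indices)
train_idx, valid_idx = indices[split:], indices[:split]
print(len(train_idx), len(valid_idx))
```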
================================================
FILE: utils/dataset_VOC.py
================================================
import random
import torch
from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
from torchvision import transforms
import torchvision.transforms.functional as TF
import numpy as np
import matplotlib.pyplot as plt
def get_pascal_labels():
"""Load the mapping that associates pascal classes with label colors
Returns:
np.ndarray with dimensions (21, 3)
"""
return np.asarray(
[
[0, 0, 0],
[128, 0, 0],
[0, 128, 0],
[128, 128, 0],
[0, 0, 128],
[128, 0, 128],
[0, 128, 128],
[128, 128, 128],
[64, 0, 0],
[192, 0, 0],
[64, 128, 0],
[192, 128, 0],
[64, 0, 128],
[192, 0, 128],
[64, 128, 128],
[192, 128, 128],
[0, 64, 0],
[128, 64, 0],
[0, 192, 0],
[128, 192, 0],
[0, 64, 128],
]
)
def decode_segmap(label_mask, plot=False):
"""Decode segmentation class labels into a color image
Args:
label_mask (np.ndarray): an (M,N) array of integer values denoting
the class label at each spatial location.
plot (bool, optional): whether to show the resulting color image
in a figure.
Returns:
(np.ndarray, optional): the resulting decoded color image.
"""
label_colours = get_pascal_labels()
r = label_mask.copy()
g = label_mask.copy()
b = label_mask.copy()
for ll in range(21):
r[label_mask == ll] = label_colours[ll, 0]
g[label_mask == ll] = label_colours[ll, 1]
b[label_mask == ll] = label_colours[ll, 2]
rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
rgb[:, :, 0] = r / 255.0
rgb[:, :, 1] = g / 255.0
rgb[:, :, 2] = b / 255.0
if plot:
plt.imshow(rgb)
plt.show()
else:
return rgb
class VOCSegmentation(Dataset):
CLASS_NAMES = [
"background",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"potted-plant",
"sheep",
"sofa",
"train",
"tv/monitor",
]
def __init__(self, root: Path, image_set="train", transformations=None, augmentations=False):
super().__init__()
assert image_set in ["train", "val", "trainval"]
voc_root = root / "VOC2012"
image_dir = voc_root / "JPEGImages"
mask_dir = voc_root / "SegmentationClass"
splits_dir = voc_root / "ImageSets" / "Segmentation"
split_f = splits_dir / (image_set.rstrip("\n") + ".txt")
with open(split_f) as f:
file_names = [x.strip() for x in f.readlines()]
self.images = [image_dir / (x + ".jpg") for x in file_names]
self.masks = [mask_dir / (x + ".png") for x in file_names]
assert len(self.images) == len(self.masks)
self.transformations = transformations
self.augmentations = augmentations
def __getitem__(self, index):
img = Image.open(self.images[index]).convert("RGB")
target = Image.open(self.masks[index])
if self.transformations is not None:
img = self.transformations(img)
target = self.transformations(target)
# Apply augmentations to both input and target
if self.augmentations:
img, target = self.apply_augmentations(img, target)
# Convert the RGB image to a tensor
toTensorTransform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
]
)
img = toTensorTransform(img)
# Convert target to long tensor
target = torch.from_numpy(np.array(target)).long()
target[target == 255] = 0
return img, target
def __len__(self):
return len(self.images)
def apply_augmentations(self, img, target):
# Horizontal flip
if random.random() > 0.5:
img = TF.hflip(img)
target = TF.hflip(target)
# Random Rotation (clockwise and counterclockwise)
if random.random() > 0.5:
degrees = 10
if random.random() > 0.5:
degrees *= -1
img = TF.rotate(img, degrees)
target = TF.rotate(target, degrees)
# Brighten or darken image (only applied to input image)
if random.random() > 0.5:
brightness = 1.2
if random.random() > 0.5:
brightness -= 0.4
img = TF.adjust_brightness(img, brightness)
return img, target
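The per-channel loop in `decode_segmap` amounts to a fancy-indexing lookup into the color table; a small sketch with a made-up 2x2 mask and the first three pascal colors:

```python
import numpy as np

label_mask = np.array([[0, 1], [2, 0]])  # tiny made-up class-label mask
label_colours = np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0]])  # first three pascal colors
rgb = label_colours[label_mask] / 255.0  # (2, 2, 3) float image, same result as the loop
```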
================================================
FILE: utils/dataset_precip.py
================================================
from torch.utils.data import Dataset
import h5py
import numpy as np
class precipitation_maps_h5(Dataset):
def __init__(self, in_file, num_input_images, num_output_images, train=True, transform=None):
super().__init__()
self.file_name = in_file
self.n_images, self.nx, self.ny = h5py.File(self.file_name, "r")["train" if train else "test"]["images"].shape
self.num_input = num_input_images
self.num_output = num_output_images
self.sequence_length = num_input_images + num_output_images
self.train = train
# Dataset is all the images
self.size_dataset = self.n_images - (num_input_images + num_output_images)
# self.size_dataset = int(self.n_images/(num_input_images+num_output_images))
self.transform = transform
self.dataset = None
def __getitem__(self, index):
# min_feature_range = 0.0
# max_feature_range = 1.0
# with h5py.File(self.file_name, 'r') as dataFile:
# dataset = dataFile["train" if self.train else "test"]['images'][index:index+self.sequence_length]
# load the file here (load as singleton)
if self.dataset is None:
self.dataset = h5py.File(self.file_name, "r", rdcc_nbytes=1024**3)["train" if self.train else "test"][
"images"
]
imgs = np.array(self.dataset[index : index + self.sequence_length], dtype="float32")
# add transforms
if self.transform is not None:
imgs = self.transform(imgs)
input_img = imgs[: self.num_input]
target_img = imgs[-1]
return input_img, target_img
def __len__(self):
return self.size_dataset
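With hypothetical numbers, the sliding-window indexing of `precipitation_maps_h5` looks like this: each index selects `sequence_length` consecutive frames, the first `num_input` are the input and the last frame is the target.

```python
import numpy as np

n_images, num_input, num_output = 20, 12, 6  # made-up archive size
sequence_length = num_input + num_output
size_dataset = n_images - sequence_length    # number of start indices, as in __init__
frames = np.arange(n_images)                 # stand-in for the HDF5 image stack
index = 0
imgs = frames[index : index + sequence_length]
input_img, target_img = imgs[:num_input], imgs[-1]
```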
class precipitation_maps_oversampled_h5(Dataset):
def __init__(self, in_file, num_input_images, num_output_images, train=True, transform=None):
super().__init__()
self.file_name = in_file
self.samples, _, _, _ = h5py.File(self.file_name, "r")["train" if train else "test"]["images"].shape
self.num_input = num_input_images
self.num_output = num_output_images
self.train = train
# self.size_dataset = int(self.n_images/(num_input_images+num_output_images))
self.transform = transform
self.dataset = None
def __getitem__(self, index):
# load the file here (load as singleton)
if self.dataset is None:
self.dataset = h5py.File(self.file_name, "r", rdcc_nbytes=1024**3)["train" if self.train else "test"][
"images"
]
imgs = np.array(self.dataset[index], dtype="float32")
# add transforms
if self.transform is not None:
imgs = self.transform(imgs)
input_img = imgs[: self.num_input]
target_img = imgs[-1]
return input_img, target_img
def __len__(self):
return self.samples
class precipitation_maps_classification_h5(Dataset):
def __init__(self, in_file, num_input_images, img_to_predict, train=True, transform=None):
super().__init__()
self.file_name = in_file
self.n_images, self.nx, self.ny = h5py.File(self.file_name, "r")["train" if train else "test"]["images"].shape
self.num_input = num_input_images
self.img_to_predict = img_to_predict
self.sequence_length = num_input_images + img_to_predict
self.bins = np.array([0.0, 0.5, 1, 2, 5, 10, 30])
self.train = train
# Dataset is all the images
self.size_dataset = self.n_images - (num_input_images + img_to_predict)
# self.size_dataset = int(self.n_images/(num_input_images+num_output_images))
self.transform = transform
self.dataset = None
def __getitem__(self, index):
# min_feature_range = 0.0
# max_feature_range = 1.0
# with h5py.File(self.file_name, 'r') as dataFile:
# dataset = dataFile["train" if self.train else "test"]['images'][index:index+self.sequence_length]
# load the file here (load as singleton)
if self.dataset is None:
self.dataset = h5py.File(self.file_name, "r", rdcc_nbytes=1024**3)["train" if self.train else "test"][
"images"
]
imgs = np.array(self.dataset[index : index + self.sequence_length], dtype="float32")
# add transforms
if self.transform is not None:
imgs = self.transform(imgs)
input_img = imgs[: self.num_input]
# put the img in buckets
target_img = imgs[-1]
        # target_img is normalized by dividing by the highest value of the training set; reverse that here.
        # target_img is then in mm/5min, while the bins are in mm/h, so also multiply by 12.
        buckets = np.digitize(target_img * 47.83 * 12, self.bins, right=True)
return input_img, buckets
def __len__(self):
return self.size_dataset
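The bucketing in `precipitation_maps_classification_h5.__getitem__` can be checked by hand: targets are stored normalized by the training-set maximum of 47.83 mm/5min, so multiplying by 47.83 * 12 = 573.96 recovers mm/h before `np.digitize` assigns class indices. A small worked example (the input pixel values are hypothetical):

```python
import numpy as np

# Class edges in mm/h, same as self.bins in precipitation_maps_classification_h5
bins = np.array([0.0, 0.5, 1, 2, 5, 10, 30])

# Hypothetical normalized target pixels (fractions of the training-set maximum)
target = np.array([0.0, 0.001, 0.01, 0.1])

# Undo the normalization (max = 47.83 mm/5min), then convert mm/5min -> mm/h
rain_mm_per_hour = target * 47.83 * 12  # approx [0.0, 0.57, 5.74, 57.4] mm/h

# right=True: a value x falls in bucket i when bins[i-1] < x <= bins[i]
buckets = np.digitize(rain_mm_per_hour, bins, right=True)
print(buckets.tolist())  # [0, 2, 5, 7]
```

Note that values above 30 mm/h land in bucket `len(bins)` = 7, so the classifier effectively has 8 classes.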
================================================
FILE: utils/formatting.py
================================================
import torch
import numpy as np
def make_metrics_str(metrics_dict):
    # Format metrics as "name: value" pairs joined by " | ", skipping NaN entries
    parts = []
    for name, value in metrics_dict.items():
        value = value.item() if isinstance(value, torch.Tensor) else value
        if not np.isnan(value):
            parts.append(f"{name}: {value:.4f}")
    return " | ".join(parts)
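`make_metrics_str` drops NaN metrics (e.g. a CSI that is undefined because no rain occurred) before formatting. A float-only re-sketch showing the filtering behavior (the repo version additionally unwraps 0-dim torch tensors via `.item()`):

```python
import numpy as np

def make_metrics_str(metrics_dict):
    # Join "name: value" pairs with " | ", skipping NaN values
    return " | ".join(
        f"{name}: {value:.4f}"
        for name, value in metrics_dict.items()
        if not np.isnan(value)
    )

print(make_metrics_str({"mse": 0.0123, "csi": float("nan"), "f1": 0.87}))
# mse: 0.0123 | f1: 0.8700
```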
================================================
FILE: utils/model_classes.py
================================================
from models import unet_precip_regression_lightning as unet_regr
import lightning.pytorch as pl
def get_model_class(model_file) -> tuple[type[pl.LightningModule], str]:
    # Map the checkpoint filename to its model class and a display name (used for plots)
if "UNetAttention" in model_file:
model_name = "UNet Attention"
model = unet_regr.UNetAttention
elif "UNetDSAttention4kpl" in model_file:
model_name = "UNetDS Attention with 4kpl"
model = unet_regr.UNetDSAttention
elif "UNetDSAttention1kpl" in model_file:
model_name = "UNetDS Attention with 1kpl"
model = unet_regr.UNetDSAttention
elif "UNetDSAttention4CBAMs" in model_file:
model_name = "UNetDS Attention 4CBAMs"
model = unet_regr.UNetDSAttention4CBAMs
elif "UNetDSAttention" in model_file:
model_name = "SmaAt-UNet"
model = unet_regr.UNetDSAttention
elif "UNetDS" in model_file:
model_name = "UNetDS"
model = unet_regr.UNetDS
elif "UNet" in model_file:
model_name = "UNet"
model = unet_regr.UNet
elif "PersistenceModel" in model_file:
model_name = "PersistenceModel"
model = unet_regr.PersistenceModel
else:
raise NotImplementedError("Model not found")
return model, model_name
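The ordering of the `elif` chain in `get_model_class` matters: it uses substring matching, so more specific names must be tested before their prefixes ("UNetDSAttention4CBAMs" before "UNetDSAttention" before "UNetDS" before "UNet"). A minimal sketch of the same dispatch with display names only (the real function also returns the Lightning module class):

```python
def get_model_name(model_file: str) -> str:
    # First matching substring wins, so specific keys come before their prefixes
    for key, name in [
        ("UNetDSAttention4CBAMs", "UNetDS Attention 4CBAMs"),
        ("UNetDSAttention", "SmaAt-UNet"),
        ("UNetDS", "UNetDS"),
        ("UNet", "UNet"),
        ("PersistenceModel", "PersistenceModel"),
    ]:
        if key in model_file:
            return name
    raise NotImplementedError("Model not found")

print(get_model_name("UNetDSAttention_epoch=19.ckpt"))  # SmaAt-UNet
```

Reversing the order would make every `UNetDS*` checkpoint match the plain "UNet" branch first.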