[
  {
    "path": ".gitignore",
    "content": ".idea\n.vscode\ndata\nruns\nlightning\ncheckpoints\nplots\n.ruff_cache\n.mypy_cache\n__pycache__\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n-   repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.5.0\n    hooks:\n    -   id: check-yaml\n    -   id: end-of-file-fixer\n    -   id: trailing-whitespace\n- repo: https://github.com/charliermarsh/ruff-pre-commit\n  rev: 'v0.9.7'\n  hooks:\n    - id: ruff\n      args: [\"--fix\"]\n    - id: ruff-format\n- repo: https://github.com/pre-commit/mirrors-mypy\n  rev: v1.15.0\n  hooks:\n  - id: mypy"
  },
  {
    "path": ".python-version",
    "content": "3.10.11\n"
  },
  {
    "path": "README.md",
    "content": "# SmaAt-UNet\nCode for the Paper \"SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture\" [Arxiv-link](https://arxiv.org/abs/2007.04417), [Elsevier-link](https://www.sciencedirect.com/science/article/pii/S0167865521000556?via%3Dihub)\n\nOther relevant publications: [Weather Elements Nowcasting](https://sites.google.com/view/siamak-mehrkanoon/projects/weather-elements-nowcasting). \n\n\n![SmaAt-UNet](SmaAt-UNet.png)\n\nThe proposed SmaAt-UNet can be found in the model-folder under [SmaAt_UNet](models/SmaAt_UNet.py).\n\n**>>>IMPORTANT<<<**\n\nThe original Code from the paper can be found in this branch: https://github.com/HansBambel/SmaAt-UNet/tree/snapshot-paper\n\nThe current master branch has since upgraded packages and was refactored. Since the exact package-versions differ the experiments may not be 100% reproducible.\n\nIf you have problems running the code, feel free to open an issue here on Github.\n\n## Installing dependencies\nThis project is using [uv](https://docs.astral.sh/uv/) as dependency management. 
Therefore, installing the required dependencies is as easy as this:\n```shell\nuv sync --frozen\n```\n\nIf you have an Nvidia GPU that supports CUDA and the command above did not install torch with CUDA support, you can install the relevant dependencies with the following command:\n```shell\nuv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n```\n\nIn any case, a [`requirements.txt`](requirements.txt) file is generated from `uv.lock` (`uv` must be installed to perform the export):\n```shell\nuv export --format=requirements-txt --no-dev --no-hashes --no-sources --output-file=requirements.txt\n```\n\nA `requirements-dev.txt` file is also generated for development dependencies:\n```shell\nuv export --format=requirements-txt --group dev --no-hashes --no-sources --output-file=requirements-dev.txt\n```\n\nA `requirements-full.txt` file is also generated for all dependencies (including development and additional dependencies):\n```shell\nuv export --format=requirements-txt --all-groups --no-hashes --no-sources --output-file=requirements-full.txt\n```\n\nAlternatively, users can use the existing [`requirements.txt`](requirements.txt), [`requirements-dev.txt`](requirements-dev.txt) and [`requirements-full.txt`](requirements-full.txt) files.\n\nOnly the following requirements are needed for training:\n```\ntqdm\ntorch\nlightning\nmatplotlib\ntensorboard\ntorchsummary\nh5py\nnumpy\n```\n\nAdditional requirements are needed for [testing](#testing):\n```\nipykernel\n```\n\n---\nFor the paper we used [Lightning](https://github.com/Lightning-AI/lightning) (PL), which simplifies the training process and allows easy addition of loggers and checkpoint creation.\nIn order to use PL we created the model [UNetDSAttention](models/unet_precip_regression_lightning.py) whose parent class inherits from `pl.LightningModule`. 
This model is the same as the pure PyTorch SmaAt-UNet implementation with the added PL functions.\n\n### Training\nAn example [training script](train_SmaAtUNet.py) is given for a classification task (PascalVOC).\n\nFor training on the precipitation task we used the [train_precip_lightning.py](train_precip_lightning.py) file.\nTraining places a checkpoint file for every model in the `default_save_path` `lightning/precip_regression`.\nAfter training finishes, place the best models that you want to compare (typically the ones with the lowest validation loss) in the folder `checkpoints/comparison`.\n\n### Testing\nThe script [calc_metrics_test_set.py](calc_metrics_test_set.py) uses all models placed in `checkpoints/comparison` to calculate the MSE as well as other metrics such as Precision, Recall, Accuracy, F1, CSI, FAR and HSS.\nThe results are saved as JSON in the same folder as the models.\n\nFor now, the metrics of the persistence model are only calculated by the script [test_precip_lightning.py](test_precip_lightning.py), which also computes the test losses for all models in that folder in addition to the persistence model.\nThis will be consolidated in [this issue](https://github.com/HansBambel/SmaAt-UNet/issues/28).\n\n### Plots\nExample code for creating plots similar to those in the paper can be found in [plot_examples.ipynb](plot_examples.ipynb).\n\nYou might need to make your kernel discoverable so that Jupyter can detect it. 
To do so, run the following command in your terminal:\n\n```shell\nuv run ipython kernel install --user --env VIRTUAL_ENV=$(pwd)/.venv --name=smaat_unet --display-name=\"SmaAt-UNet Kernel\"\n```\n\n### Precipitation dataset\nThe dataset consists of precipitation maps in 5-minute intervals from 2016-2019, resulting in about 420,000 images.\n\nThe dataset is based on radar precipitation maps from [the Royal Netherlands Meteorological Institute (KNMI)](https://www.knmi.nl/over-het-knmi/about).\nThe original images were cropped as can be seen in the example below:\n![Precip cutout](Precipitation%20map%20Cutout.png)\n\nIf you are interested in the dataset that we used, please write an e-mail to: k.trebing@alumni.maastrichtuniversity.nl and s.mehrkanoon@uu.nl\n\nThe 50% dataset is 4 GB in size and the 20% dataset is 16.5 GB. Use [create_datasets.py](create_datasets.py) to create the two datasets from the original dataset.\n\nThe dataset is already normalized using [Min-Max normalization](https://en.wikipedia.org/wiki/Feature_scaling#Rescaling_(min-max_normalization)).\nTo revert this, multiply the images by 47.83; the images then show precipitation in mm/5min.\n\n### Citation\n```\n@article{TREBING2021,\ntitle = {SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture},\njournal = {Pattern Recognition Letters},\nyear = {2021},\nissn = {0167-8655},\ndoi = {https://doi.org/10.1016/j.patrec.2021.01.036},\nurl = {https://www.sciencedirect.com/science/article/pii/S0167865521000556},\nauthor = {Kevin Trebing and Tomasz Staǹczyk and Siamak Mehrkanoon},\nkeywords = {Domain adaptation, neural networks, kernel methods, coupling regularization},\nabstract = {Weather forecasting is dominated by numerical weather prediction that tries to model accurately the physical properties of the atmosphere. 
A downside of numerical weather prediction is that it is lacking the ability for short-term forecasts using the latest available information. By using a data-driven neural network approach we show that it is possible to produce an accurate precipitation nowcast. To this end, we propose SmaAt-UNet, an efficient convolutional neural networks-based on the well known UNet architecture equipped with attention modules and depthwise-separable convolutions. We evaluate our approaches on a real-life datasets using precipitation maps from the region of the Netherlands and binary images of cloud coverage of France. The experimental results show that in terms of prediction performance, the proposed model is comparable to other examined models while only using a quarter of the trainable parameters.}\n}\n\n```\n\n### Other relevant publications\n\nYour project might also benefit from exploring our other relevant publications: [Weather Elements Nowcasting](https://sites.google.com/view/siamak-mehrkanoon/projects/weather-elements-nowcasting). Please consider citing them if you find them useful.\n\n"
  },
  {
    "path": "calc_metrics_test_set.py",
    "content": "import argparse\nfrom argparse import Namespace\nimport json\nimport csv\nimport numpy as np\nimport os\nfrom pprint import pprint\nimport torch\nfrom tqdm import tqdm\nimport lightning.pytorch as pl\nimport matplotlib.pyplot as plt\n\nfrom metric.precipitation_metrics import PrecipitationMetrics\nfrom models.unet_precip_regression_lightning import PersistenceModel\nfrom root import ROOT_DIR\nfrom utils import dataset_precip, model_classes\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Calculate metrics for precipitation models\")\n\n    parser.add_argument(\"--model-folder\", type=str, default=None,\n                        help=\"Path to model folder (default: ROOT_DIR/checkpoints/comparison)\")\n    parser.add_argument(\"--threshold\", type=float, default=0.5, help=\"Precipitation threshold)\")\n    parser.add_argument(\"--normalized\", action=\"store_false\", dest=\"denormalize\", default=True,\n                        help=\"Calculate metrics for normalized data\")\n    parser.add_argument(\"--load-metrics\", action=\"store_true\",\n                        help=\"Load existing metrics instead of running experiments\")\n    parser.add_argument(\"--file-type\", type=str, default=\"txt\", choices=[\"json\", \"txt\", \"csv\"],\n                        help=\"Output file format (json, txt, or csv)\")\n    parser.add_argument(\"--save-dir\", type=str, default=None,\n                        help=\"Path to save the metrics\")\n    parser.add_argument(\"--plot\", action=\"store_true\", help=\"Plot the metrics\")\n    \n    return parser.parse_args()\n\n\ndef convert_tensors_to_python(obj):\n    \"\"\"Convert PyTorch tensors to Python native types recursively.\"\"\"\n    if isinstance(obj, torch.Tensor):\n        # Convert tensor to float/int, handling nan values\n        value = obj.item() if obj.numel() == 1 else obj.tolist()\n        return float('nan') if isinstance(value, float) and np.isnan(value) else value\n    elif 
isinstance(obj, dict):\n        return {key: convert_tensors_to_python(value) for key, value in obj.items()}\n    elif isinstance(obj, list):\n        return [convert_tensors_to_python(value) for value in obj]\n    return obj\n\n\ndef save_metrics(results, file_path, file_type=\"json\"):\n    \"\"\"\n    Save metrics to a file in the specified format.\n\n    Args:\n        results (dict): Dictionary containing model metrics\n        file_path (Path): Path to save the file\n        file_type (str): File format - \"json\", \"txt\", or \"csv\"\n    \"\"\"\n    with open(file_path, \"w\") as f:\n        if file_type == \"csv\":\n            writer = csv.writer(f)\n            if results:\n                # Write header row\n                first_model = next(iter(results.values()))\n                writer.writerow([\"Model\"] + list(first_model.keys()))\n                # Write data rows\n                for model_name, metrics in results.items():\n                    writer.writerow([model_name] + list(metrics.values()))\n        else:\n            # Both json and txt formats are written with json.dump\n            json.dump(results, f, indent=4)\n\n\ndef run_experiments(model_folder, data_file, threshold=0.5, denormalize=True):\n    \"\"\"\n    Run test experiments for all models in the model folder.\n\n    Args:\n        model_folder (Path): Path to the model folder\n        data_file (Path): Path to the data file\n        threshold (float): Threshold for the precipitation (in mm/h)\n        denormalize (bool): Whether to compute the metrics on denormalized data\n\n    Returns:\n        dict: Dictionary containing model metrics\n    \"\"\"\n    results = {}\n\n    # Load the dataset\n    dataset = dataset_precip.precipitation_maps_oversampled_h5(\n        in_file=data_file, num_input_images=12, num_output_images=6, train=False\n    )\n\n    # Create dataloader\n    test_dl = torch.utils.data.DataLoader(\n        dataset, batch_size=1, num_workers=1, shuffle=False, pin_memory=True, persistent_workers=True\n    )\n\n    # Create trainer without logging or checkpointing\n    trainer = 
pl.Trainer(logger=False, enable_checkpointing=False)\n\n    # Find all model checkpoints in the model folder (the persistence model needs no checkpoint)\n    checkpoint_files = [f for f in os.listdir(model_folder) if f.endswith(\".ckpt\")]\n    if not checkpoint_files:\n        raise ValueError(\"No checkpoint files found in the model folder.\")\n    models = [\"PersistenceModel\"] + checkpoint_files\n\n    for model_file in tqdm(models, desc=\"Models\", leave=True):\n        model, model_name = model_classes.get_model_class(model_file)\n        print(f\"Loading model: {model_name}\")\n\n        if model_file == \"PersistenceModel\":\n            loaded_model = model(Namespace())\n        else:\n            loaded_model = model.load_from_checkpoint(os.path.join(model_folder, model_file))\n\n        # Override the test metrics, required for testing with normalized data\n        loaded_model.test_metrics = PrecipitationMetrics(threshold=threshold, denormalize=denormalize)\n\n        results[model_name] = trainer.test(model=loaded_model, dataloaders=[test_dl])[0]\n\n    return results\n\n\ndef plot_metrics(results, save_path=\"\"):\n    \"\"\"\n    Plot all metrics from the results dictionary.\n\n    Args:\n        results (dict): Dictionary containing metrics for different models\n        save_path (str): Path to save the plots\n    \"\"\"\n    model_names = list(results.keys())\n    metrics = list(results[model_names[0]].keys())\n\n    for metric_name in metrics:\n        # Collect values, skipping NaN entries so they do not distort the plot\n        values = []\n        valid_names = []\n\n        for model_name in model_names:\n            val = results[model_name][metric_name]\n            # Skip NaN values\n            if not (isinstance(val, float) and np.isnan(val)):\n                values.append(val)\n                valid_names.append(model_name)\n\n        if len(valid_names) == 0:\n            continue\n\n        plt.figure(figsize=(10, 6))\n        plt.bar(valid_names, values)\n        plt.xticks(rotation=45)\n        
plt.xlabel(\"Models\")\n        plt.ylabel(f\"{metric_name.upper()}\")\n        plt.title(f\"Comparison of different models - {metric_name.upper()}\")\n        plt.tight_layout()\n        \n        if save_path:\n            plt.savefig(f\"{save_path}/{metric_name}.png\")\n            \n        plt.show()\n\n    \ndef load_metrics_from_file(file_path, file_type):\n    \"\"\"\n    Load metrics from a file in the specified format.\n    \n    Args:\n        file_path (Path): Path to the metrics file\n        file_type (str): File format - \"json\", \"txt\", or \"csv\"\n        \n    Returns:\n        dict: Dictionary containing model metrics\n    \"\"\"\n    results = {}\n    \n    with open(file_path) as f:\n        if file_type == \"json\" or file_type == \"txt\":\n            # Both json and txt files are saved in JSON format in this application\n            results = json.load(f)\n        elif file_type == \"csv\":\n            reader = csv.reader(f)\n            headers = next(reader)  # Get header row\n            for row in reader:\n                if len(row) > 0:\n                    model_name = row[0]\n                    # Create dictionary of metrics for this model\n                    model_metrics = {}\n                    for i in range(1, len(headers)):\n                        if i < len(row):\n                            # Try to convert to float if possible\n                            try:\n                                value = float(row[i])\n                            except (ValueError, TypeError):\n                                value = row[i]\n                            model_metrics[headers[i]] = value\n                    results[model_name] = model_metrics\n    \n    return results\n\n\ndef main():\n    # Parse command line arguments\n    args = parse_args()\n    \n    # Variables from command line arguments\n    load_metrics = args.load_metrics\n    threshold = args.threshold\n    file_type = args.file_type\n    denormalize = 
args.denormalize\n\n    # Define paths (os.path.join also handles user-supplied string paths from the CLI)\n    model_folder = ROOT_DIR / \"checkpoints\" / \"comparison\" if args.model_folder is None else args.model_folder\n    data_file = ROOT_DIR / \"data\" / \"precipitation\" / f\"train_test_2016-2019_input-length_12_img-ahead_6_rain-threshold_{int(threshold*100)}.h5\"\n    save_dir = (\n        os.path.join(model_folder, f\"calc_metrics_test_set_results_{'denorm' if denormalize else 'norm'}\")\n        if args.save_dir is None\n        else args.save_dir\n    )\n    os.makedirs(save_dir, exist_ok=True)\n\n    file_name = f\"model_metrics_{threshold}_{'denorm' if denormalize else 'norm'}.{file_type}\"\n    file_path = os.path.join(save_dir, file_name)\n\n    # Load metrics if available\n    if load_metrics:\n        print(f\"Loading metrics from {file_path}\")\n        results = load_metrics_from_file(file_path, file_type)\n    else:\n        # Run experiments\n        print(f\"Running experiments with {threshold} threshold and {'denormalized' if denormalize else 'normalized'} data\")\n        results = run_experiments(model_folder, data_file, threshold=threshold, denormalize=denormalize)\n\n        # Convert tensors to Python native types\n        results = convert_tensors_to_python(results)\n\n        save_metrics(results, file_path, file_type)\n\n    # Display results\n    print(f\"Metrics saved to {file_path}\")\n    pprint(results)\n\n    # Plot metrics if requested\n    if args.plot:\n        plots_dir = os.path.join(save_dir, \"plots\")\n        os.makedirs(plots_dir, exist_ok=True)\n\n        plot_metrics(results, plots_dir)\n        print(f\"Plots saved to {plots_dir}\")\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "create_datasets.py",
    "content": "import h5py\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom root import ROOT_DIR\n\n\ndef create_dataset(input_length: int, image_ahead: int, rain_amount_thresh: float):\n    \"\"\"Create a dataset that has target images containing at least `rain_amount_thresh` (percent) of rain.\"\"\"\n\n    precipitation_folder = ROOT_DIR / \"data\" / \"precipitation\"\n    with h5py.File(\n        precipitation_folder / \"RAD_NL25_RAC_5min_train_test_2016-2019.h5\",\n        \"r\",\n    ) as orig_f:\n        train_images = orig_f[\"train\"][\"images\"]\n        train_timestamps = orig_f[\"train\"][\"timestamps\"]\n        test_images = orig_f[\"test\"][\"images\"]\n        test_timestamps = orig_f[\"test\"][\"timestamps\"]\n        print(\"Train shape\", train_images.shape)\n        print(\"Test shape\", test_images.shape)\n        imgSize = train_images.shape[1]\n        num_pixels = imgSize * imgSize\n\n        filename = (\n            precipitation_folder / f\"train_test_2016-2019_input-length_{input_length}_img-\"\n            f\"ahead_{image_ahead}_rain-threshold_{int(rain_amount_thresh * 100)}.h5\"\n        )\n\n        with h5py.File(filename, \"w\") as f:\n            train_set = f.create_group(\"train\")\n            test_set = f.create_group(\"test\")\n            train_image_dataset = train_set.create_dataset(\n                \"images\",\n                shape=(1, input_length + image_ahead, imgSize, imgSize),\n                maxshape=(None, input_length + image_ahead, imgSize, imgSize),\n                dtype=\"float32\",\n                compression=\"gzip\",\n                compression_opts=9,\n            )\n            train_timestamp_dataset = train_set.create_dataset(\n                \"timestamps\",\n                shape=(1, input_length + image_ahead, 1),\n                maxshape=(None, input_length + image_ahead, 1),\n                dtype=h5py.special_dtype(vlen=str),\n                compression=\"gzip\",\n                
compression_opts=9,\n            )\n            test_image_dataset = test_set.create_dataset(\n                \"images\",\n                shape=(1, input_length + image_ahead, imgSize, imgSize),\n                maxshape=(None, input_length + image_ahead, imgSize, imgSize),\n                dtype=\"float32\",\n                compression=\"gzip\",\n                compression_opts=9,\n            )\n            test_timestamp_dataset = test_set.create_dataset(\n                \"timestamps\",\n                shape=(1, input_length + image_ahead, 1),\n                maxshape=(None, input_length + image_ahead, 1),\n                dtype=h5py.special_dtype(vlen=str),\n                compression=\"gzip\",\n                compression_opts=9,\n            )\n\n            origin = [[train_images, train_timestamps], [test_images, test_timestamps]]\n            datasets = [\n                [train_image_dataset, train_timestamp_dataset],\n                [test_image_dataset, test_timestamp_dataset],\n            ]\n            for origin_id, (images, timestamps) in enumerate(origin):\n                image_dataset, timestamp_dataset = datasets[origin_id]\n\n                # Pre-calculate all valid indices that meet the rain threshold\n                sequence_length = input_length + image_ahead\n                valid_indices = []\n\n                # Scan the images and record every index whose frame meets the rain threshold\n                for i in tqdm(range(sequence_length, len(images)), desc=\"Finding valid indices\"):\n                    rain_pixels = np.sum(images[i] > 0)\n                    if rain_pixels >= num_pixels * rain_amount_thresh:\n                        valid_indices.append(i)\n\n                # Pre-allocate the final dataset size\n                total_sequences = len(valid_indices)\n                image_dataset.resize(total_sequences, axis=0)\n                timestamp_dataset.resize(total_sequences, axis=0)\n\n                # Copy each valid sequence into the new dataset\n                for idx, i in enumerate(tqdm(valid_indices, desc=\"Creating sequences\")):\n                    imgs = images[i - sequence_length : i]\n                    timestamps_img = timestamps[i - sequence_length : i]\n                    image_dataset[idx] = imgs\n                    timestamp_dataset[idx] = timestamps_img\n\n\nif __name__ == \"__main__\":\n    print(\"Creating dataset with at least 20% rain pixels in the target image\")\n    create_dataset(input_length=12, image_ahead=6, rain_amount_thresh=0.2)\n    print(\"Creating dataset with at least 50% rain pixels in the target image\")\n    create_dataset(input_length=12, image_ahead=6, rain_amount_thresh=0.5)\n"
  },
  {
    "path": "metric/__init__.py",
    "content": "from .confusionmatrix import ConfusionMatrix\nfrom .iou import IoU\nfrom .metric import Metric\n\n__all__ = [\"ConfusionMatrix\", \"IoU\", \"Metric\"]\n"
  },
  {
    "path": "metric/confusionmatrix.py",
    "content": "import numpy as np\nimport torch\nfrom metric import metric\n\n\nclass ConfusionMatrix(metric.Metric):\n    \"\"\"Constructs a confusion matrix for a multi-class classification problems.\n    Does not support multi-label, multi-class problems.\n    Keyword arguments:\n    - num_classes (int): number of classes in the classification problem.\n    - normalized (boolean, optional): Determines whether or not the confusion\n    matrix is normalized or not. Default: False.\n    Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py\n    \"\"\"\n\n    def __init__(self, num_classes, normalized=False):\n        super().__init__()\n\n        self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32)\n        self.normalized = normalized\n        self.num_classes = num_classes\n        self.reset()\n\n    def reset(self):\n        self.conf.fill(0)\n\n    def add(self, predicted, target):\n        \"\"\"Computes the confusion matrix\n        The shape of the confusion matrix is K x K, where K is the number\n        of classes.\n        Keyword arguments:\n        - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of\n        predicted scores obtained from the model for N examples and K classes,\n        or an N-tensor/array of integer values between 0 and K-1.\n        - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of\n        ground-truth classes for N examples and K classes, or an N-tensor/array\n        of integer values between 0 and K-1.\n        \"\"\"\n        # If target and/or predicted are tensors, convert them to numpy arrays\n        if torch.is_tensor(predicted):\n            predicted = predicted.cpu().numpy()\n        if torch.is_tensor(target):\n            target = target.cpu().numpy()\n\n        assert predicted.shape[0] == target.shape[0], \"number of targets and predicted outputs do not match\"\n\n        if np.ndim(predicted) != 1:\n            assert (\n       
         predicted.shape[1] == self.num_classes\n            ), \"number of predictions does not match size of confusion matrix\"\n            predicted = np.argmax(predicted, 1)\n        else:\n            assert (predicted.max() < self.num_classes) and (\n                predicted.min() >= 0\n            ), \"predicted values are not between 0 and k-1\"\n\n        if np.ndim(target) != 1:\n            assert target.shape[1] == self.num_classes, \"Onehot target does not match size of confusion matrix\"\n            assert (target >= 0).all() and (target <= 1).all(), \"in one-hot encoding, target values should be 0 or 1\"\n            assert (target.sum(1) == 1).all(), \"multi-label setting is not supported\"\n            target = np.argmax(target, 1)\n        else:\n            assert (target.max() < self.num_classes) and (target.min() >= 0), \"target values are not between 0 and k-1\"\n\n        # hack for bincounting 2 arrays together\n        x = predicted + self.num_classes * target\n        bincount_2d = np.bincount(x.astype(np.int32), minlength=self.num_classes**2)\n        assert bincount_2d.size == self.num_classes**2\n        conf = bincount_2d.reshape((self.num_classes, self.num_classes))\n\n        self.conf += conf\n\n    def value(self):\n        \"\"\"\n        Returns:\n            Confustion matrix of K rows and K columns, where rows corresponds\n            to ground-truth targets and columns corresponds to predicted\n            targets.\n        \"\"\"\n        if self.normalized:\n            conf = self.conf.astype(np.float32)\n            return conf / conf.sum(1).clip(min=1e-12)[:, None]\n        else:\n            return self.conf\n"
  },
  {
    "path": "metric/iou.py",
    "content": "import numpy as np\nfrom metric import metric\nfrom metric.confusionmatrix import ConfusionMatrix\n\n\n# Taken from https://github.com/davidtvs/PyTorch-ENet/tree/master/metric\nclass IoU(metric.Metric):\n    \"\"\"Computes the intersection over union (IoU) per class and corresponding\n    mean (mIoU).\n    Intersection over union (IoU) is a common evaluation metric for semantic\n    segmentation. The predictions are first accumulated in a confusion matrix\n    and the IoU is computed from it as follows:\n        IoU = true_positive / (true_positive + false_positive + false_negative).\n    Keyword arguments:\n    - num_classes (int): number of classes in the classification problem\n    - normalized (boolean, optional): Determines whether or not the confusion\n    matrix is normalized or not. Default: False.\n    - ignore_index (int or iterable, optional): Index of the classes to ignore\n    when computing the IoU. Can be an int, or any iterable of ints.\n    \"\"\"\n\n    def __init__(self, num_classes, normalized=False, ignore_index=None):\n        super().__init__()\n        self.conf_metric = ConfusionMatrix(num_classes, normalized)\n\n        if ignore_index is None:\n            self.ignore_index = None\n        elif isinstance(ignore_index, int):\n            self.ignore_index = (ignore_index,)\n        else:\n            try:\n                self.ignore_index = tuple(ignore_index)\n            except TypeError as err:\n                raise ValueError(\"'ignore_index' must be an int or iterable\") from err\n\n    def reset(self):\n        self.conf_metric.reset()\n\n    def add(self, predicted, target):\n        \"\"\"Adds the predicted and target pair to the IoU metric.\n        Keyword arguments:\n        - predicted (Tensor): Can be a (N, K, H, W) tensor of\n        predicted scores obtained from the model for N examples and K classes,\n        or (N, H, W) tensor of integer values between 0 and K-1.\n        - target (Tensor): Can be a 
(N, K, H, W) tensor of\n        target scores for N examples and K classes, or (N, H, W) tensor of\n        integer values between 0 and K-1.\n        \"\"\"\n        # Dimensions check\n        assert predicted.size(0) == target.size(0), \"number of targets and predicted outputs do not match\"\n        assert (\n            predicted.dim() == 3 or predicted.dim() == 4\n        ), \"predictions must be of dimension (N, H, W) or (N, K, H, W)\"\n        assert target.dim() == 3 or target.dim() == 4, \"targets must be of dimension (N, H, W) or (N, K, H, W)\"\n\n        # If the tensor is in categorical format convert it to integer format\n        if predicted.dim() == 4:\n            _, predicted = predicted.max(1)\n        if target.dim() == 4:\n            _, target = target.max(1)\n\n        self.conf_metric.add(predicted.view(-1), target.view(-1))\n\n    def value(self):\n        \"\"\"Computes the IoU and mean IoU.\n        The mean computation ignores NaN elements of the IoU array.\n        Returns:\n            Tuple: (IoU, mIoU). The first output is the per class IoU,\n            for K classes it's numpy.ndarray with K elements. The second output,\n            is the mean IoU.\n        \"\"\"\n        conf_matrix = self.conf_metric.value()\n        if self.ignore_index is not None:\n            conf_matrix[:, self.ignore_index] = 0\n            conf_matrix[self.ignore_index, :] = 0\n        true_positive = np.diag(conf_matrix)\n        false_positive = np.sum(conf_matrix, 0) - true_positive\n        false_negative = np.sum(conf_matrix, 1) - true_positive\n\n        # Just in case we get a division by 0, ignore/hide the error\n        with np.errstate(divide=\"ignore\", invalid=\"ignore\"):\n            iou = true_positive / (true_positive + false_positive + false_negative)\n\n        return iou, np.nanmean(iou)\n"
  },
  {
    "path": "metric/metric.py",
    "content": "from typing import Protocol\n\n\nclass Metric(Protocol):\n    \"\"\"Base class for all metrics.\"\"\"\n\n    def reset(self):\n        pass\n\n    def add(self, predicted, target):\n        pass\n\n    def value(self):\n        pass\n"
  },
  {
    "path": "metric/precipitation_metrics.py",
    "content": "import torch\nfrom torchmetrics import Metric\nimport numpy as np\n\n\nclass PrecipitationMetrics(Metric):\n    \"\"\"\n    A custom metric for precipitation forecasting that computes both regression metrics (MSE)\n    and classification metrics (precision, recall, F1, etc.) based on a threshold.\n    \"\"\"\n    \n    def __init__(self, threshold=0.5, denormalize=True, dist_sync_on_step=False):\n        \"\"\"\n        Args:\n            threshold: Threshold to convert continuous predictions into binary masks (in mm/h)\n            denormalize: If True, applies the factor to undo normalization\n            dist_sync_on_step: Synchronize metric state across processes at each step\n        \"\"\"\n        super().__init__(dist_sync_on_step=dist_sync_on_step)\n        \n        self.threshold = threshold\n        self.denormalize = denormalize\n        self.factor = 47.83 # Factor to denormalize predictions (mm/h)\n        \n        # Add states for regression metrics\n        self.add_state(\"total_loss\", default=torch.tensor(0.0), dist_reduce_fx=\"sum\")\n        self.add_state(\"total_loss_denorm\", default=torch.tensor(0.0), dist_reduce_fx=\"sum\")\n        self.add_state(\"total_samples\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n        self.add_state(\"total_pixels\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n        \n        # Add states for classification metrics\n        self.add_state(\"total_tp\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n        self.add_state(\"total_fp\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n        self.add_state(\"total_tn\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n        self.add_state(\"total_fn\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n    \n    def update(self, preds, target):\n        \"\"\"\n        Update the metric states with new predictions and targets.\n        \n        Args:\n            preds: Model predictions (normalized)\n            
target: Ground truth (normalized)\n        \"\"\"\n        # Check for NaN values\n        if torch.isnan(preds).any() or torch.isnan(target).any():\n            print(\"Warning: NaN values detected in predictions or targets\")\n            return\n        \n        # Make sure preds and target have the same shape\n        if preds.shape != target.shape:\n            if len(preds.shape) < len(target.shape):\n                preds = preds.unsqueeze(0)\n            elif len(preds.shape) > len(target.shape):\n                preds = preds.squeeze()\n                # If squeezing made it too small, add dimension back\n                if len(preds.shape) < len(target.shape):\n                    preds = preds.unsqueeze(0)\n        \n        # Calculate MSE loss (normalized)\n        batch_size = target.size(0)\n        loss = torch.nn.functional.mse_loss(preds, target, reduction=\"sum\") / batch_size\n        self.total_loss += loss\n        self.total_samples += batch_size  # Use batch_size instead of 1\n        self.total_pixels += target.numel()\n\n        # Denormalize if needed\n        if self.denormalize:\n            preds_updated = preds * self.factor\n            target_updated = target * self.factor\n\n            # Calculate denormalized MSE loss\n            loss_denorm = torch.nn.functional.mse_loss(preds_updated, target_updated, reduction=\"sum\") / batch_size\n            self.total_loss_denorm += loss_denorm\n        else:\n            preds_updated = preds\n            target_updated = target\n            \n        # Convert to mm/h for classification metrics (multiply by 12 for 5min to hourly rate)\n        preds_hourly = preds_updated * 12\n        target_hourly = target_updated * 12\n        \n        # Apply threshold to get binary masks\n        preds_mask = preds_hourly > self.threshold\n        target_mask = target_hourly > self.threshold\n        \n        # Compute confusion matrix\n        confusion = target_mask.view(-1) * 2 + 
preds_mask.view(-1)\n        bincount = torch.bincount(confusion, minlength=4)\n        \n        # Update confusion matrix states\n        self.total_tn += bincount[0]\n        self.total_fp += bincount[1]\n        self.total_fn += bincount[2]\n        self.total_tp += bincount[3]\n    \n    def compute(self):\n        \"\"\"\n        Compute the final metrics from the accumulated states.\n        \n        Returns:\n            A dictionary containing all metrics.\n        \"\"\"\n        # Compute regression metrics\n        mse = self.total_loss / self.total_samples\n        mse_denorm = self.total_loss_denorm / self.total_samples if self.denormalize else torch.tensor(float('nan'))\n        mse_pixel = self.total_loss_denorm / self.total_pixels if self.denormalize else torch.tensor(float('nan'))\n        \n        # Compute classification metrics\n        precision = self.total_tp / (self.total_tp + self.total_fp) if (self.total_tp + self.total_fp) > 0 else torch.tensor(float('nan'))\n        recall = self.total_tp / (self.total_tp + self.total_fn) if (self.total_tp + self.total_fn) > 0 else torch.tensor(float('nan'))\n        accuracy = (self.total_tp + self.total_tn) / (self.total_tp + self.total_tn + self.total_fp + self.total_fn)\n        \n        # F1 score\n        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(float('nan'))\n        \n        # Critical Success Index (CSI) or Threat Score\n        csi = self.total_tp / (self.total_tp + self.total_fn + self.total_fp) if (self.total_tp + self.total_fn + self.total_fp) > 0 else torch.tensor(float('nan'))\n        \n        # False Alarm Ratio (FAR)\n        far = self.total_fp / (self.total_tp + self.total_fp) if (self.total_tp + self.total_fp) > 0 else torch.tensor(float('nan'))\n        \n        # Heidke Skill Score (HSS)\n        denom = ((self.total_tp + self.total_fn) * (self.total_fn + self.total_tn) + \n                 (self.total_tp + 
self.total_fp) * (self.total_fp + self.total_tn))\n        hss = ((self.total_tp * self.total_tn) - (self.total_fn * self.total_fp)) / denom if denom > 0 else torch.tensor(float('nan'))\n        \n        return {\n            \"mse\": mse,\n            \"mse_denorm\": mse_denorm,\n            \"mse_pixel\": mse_pixel,\n            \"precision\": precision,\n            \"recall\": recall,\n            \"accuracy\": accuracy,\n            \"f1\": f1,\n            \"csi\": csi,\n            \"far\": far,\n            \"hss\": hss\n        } "
  },
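The confusion-matrix bookkeeping in `PrecipitationMetrics.update` packs each pixel pair into a 2-bit code (`target * 2 + pred`) and bincounts it. A minimal torch-free sketch of that encoding and two of the derived forecast scores (function names here are illustrative, not part of the repo):

```python
def confusion_counts(target_mask, preds_mask):
    """Count TN/FP/FN/TP with the same 2-bit encoding as PrecipitationMetrics."""
    counts = [0, 0, 0, 0]  # index = target*2 + pred -> TN, FP, FN, TP
    for t, p in zip(target_mask, preds_mask):
        counts[int(t) * 2 + int(p)] += 1
    tn, fp, fn, tp = counts
    return tn, fp, fn, tp


def csi(tp, fp, fn):
    """Critical Success Index: hits / (hits + misses + false alarms)."""
    denom = tp + fn + fp
    return tp / denom if denom > 0 else float("nan")


def far(tp, fp):
    """False Alarm Ratio: false alarms / all positive predictions."""
    denom = tp + fp
    return fp / denom if denom > 0 else float("nan")
```

For example, with target `[1, 1, 0, 0]` and prediction `[1, 0, 1, 0]` each of TN/FP/FN/TP occurs once, giving CSI = 1/3 and FAR = 0.5.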
  {
    "path": "models/SmaAt_UNet.py",
    "content": "from torch import nn\nfrom models.unet_parts import OutConv\nfrom models.unet_parts_depthwise_separable import DoubleConvDS, UpDS, DownDS\nfrom models.layers import CBAM\n\n\nclass SmaAt_UNet(nn.Module):\n    def __init__(\n        self,\n        n_channels,\n        n_classes,\n        kernels_per_layer=2,\n        bilinear=True,\n        reduction_ratio=16,\n    ):\n        super().__init__()\n        self.n_channels = n_channels\n        self.n_classes = n_classes\n        kernels_per_layer = kernels_per_layer\n        self.bilinear = bilinear\n        reduction_ratio = reduction_ratio\n\n        self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)\n        self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)\n        self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)\n        self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)\n        self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)\n        self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)\n        self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)\n        self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)\n        factor = 2 if self.bilinear else 1\n        self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)\n        self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)\n        self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)\n        self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)\n        self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)\n        self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)\n\n        self.outc = OutConv(64, self.n_classes)\n\n    def forward(self, x):\n        x1 = self.inc(x)\n        x1Att = self.cbam1(x1)\n        x2 = self.down1(x1)\n        x2Att = self.cbam2(x2)\n        x3 = 
self.down2(x2)\n        x3Att = self.cbam3(x3)\n        x4 = self.down3(x3)\n        x4Att = self.cbam4(x4)\n        x5 = self.down4(x4)\n        x5Att = self.cbam5(x5)\n        x = self.up1(x5Att, x4Att)\n        x = self.up2(x, x3Att)\n        x = self.up3(x, x2Att)\n        x = self.up4(x, x1Att)\n        logits = self.outc(x)\n        return logits\n"
  },
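SmaAt_UNet's "small" size comes from the `DepthwiseSeparableConv` blocks it uses in place of standard convolutions. The parameter saving can be checked with simple weight-count arithmetic (bias terms ignored for brevity; these helpers are a sketch, not repo code):

```python
def conv_params(c_in, c_out, k):
    """Weight count of a standard k x k 2D convolution (bias ignored)."""
    return c_in * c_out * k * k


def ds_conv_params(c_in, c_out, k, kernels_per_layer=1):
    """Weight count of a depthwise-separable convolution as in models/layers.py."""
    # Depthwise: kernels_per_layer k x k filters per input channel (groups=c_in)
    depthwise = c_in * kernels_per_layer * k * k
    # Pointwise: 1x1 convolution mixing the expanded channels into c_out
    pointwise = c_in * kernels_per_layer * c_out
    return depthwise + pointwise
```

For a 64-to-128 channel 3x3 layer with `kernels_per_layer=2` (the SmaAt_UNet default), the standard convolution needs 73,728 weights versus 17,536 for the depthwise-separable version, roughly a 4x reduction per layer.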
  {
    "path": "models/__init__.py",
    "content": ""
  },
  {
    "path": "models/layers.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\n\n\n# Taken from https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14\nclass DepthToSpace(nn.Module):\n    def __init__(self, block_size):\n        super().__init__()\n        self.bs = block_size\n\n    def forward(self, x):\n        N, C, H, W = x.size()\n        x = x.view(N, self.bs, self.bs, C // (self.bs**2), H, W)  # (N, bs, bs, C//bs^2, H, W)\n        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)\n        x = x.view(N, C // (self.bs**2), H * self.bs, W * self.bs)  # (N, C//bs^2, H * bs, W * bs)\n        return x\n\n\nclass SpaceToDepth(nn.Module):\n    # Expects the following shape: Batch, Channel, Height, Width\n    def __init__(self, block_size):\n        super().__init__()\n        self.bs = block_size\n\n    def forward(self, x):\n        N, C, H, W = x.size()\n        x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)  # (N, C, H//bs, bs, W//bs, bs)\n        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)\n        x = x.view(N, C * (self.bs**2), H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)\n        return x\n\n\nclass DepthwiseSeparableConv(nn.Module):\n    def __init__(self, in_channels, output_channels, kernel_size, padding=0, kernels_per_layer=1):\n        super().__init__()\n        # In Tensorflow DepthwiseConv2D has depth_multiplier instead of kernels_per_layer\n        self.depthwise = nn.Conv2d(\n            in_channels,\n            in_channels * kernels_per_layer,\n            kernel_size=kernel_size,\n            padding=padding,\n            groups=in_channels,\n        )\n        self.pointwise = nn.Conv2d(in_channels * kernels_per_layer, output_channels, kernel_size=1)\n\n    def forward(self, x):\n        x = self.depthwise(x)\n        x = self.pointwise(x)\n        return x\n\n\nclass DoubleDense(nn.Module):\n    def 
__init__(self, in_channels, hidden_neurons, output_channels):\n        super().__init__()\n        self.dense1 = nn.Linear(in_channels, out_features=hidden_neurons)\n        self.dense2 = nn.Linear(in_features=hidden_neurons, out_features=hidden_neurons // 2)\n        self.dense3 = nn.Linear(in_features=hidden_neurons // 2, out_features=output_channels)\n\n    def forward(self, x):\n        out = F.relu(self.dense1(x.view(x.size(0), -1)))\n        out = F.relu(self.dense2(out))\n        out = self.dense3(out)\n        return out\n\n\nclass DoubleDSConv(nn.Module):\n    \"\"\"(convolution => [BN] => ReLU) * 2\"\"\"\n\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.double_ds_conv = nn.Sequential(\n            DepthwiseSeparableConv(in_channels, out_channels, kernel_size=3, padding=1),\n            nn.BatchNorm2d(out_channels),\n            nn.ReLU(inplace=True),\n            DepthwiseSeparableConv(out_channels, out_channels, kernel_size=3, padding=1),\n            nn.BatchNorm2d(out_channels),\n            nn.ReLU(inplace=True),\n        )\n\n    def forward(self, x):\n        return self.double_ds_conv(x)\n\n\nclass Flatten(nn.Module):\n    def forward(self, x):\n        return x.view(x.size(0), -1)\n\n\nclass ChannelAttention(nn.Module):\n    def __init__(self, input_channels, reduction_ratio=16):\n        super().__init__()\n        self.input_channels = input_channels\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.max_pool = nn.AdaptiveMaxPool2d(1)\n        #  https://github.com/luuuyi/CBAM.PyTorch/blob/master/model/resnet_cbam.py\n        #  uses Convolutions instead of Linear\n        self.MLP = nn.Sequential(\n            Flatten(),\n            nn.Linear(input_channels, input_channels // reduction_ratio),\n            nn.ReLU(),\n            nn.Linear(input_channels // reduction_ratio, input_channels),\n        )\n\n    def forward(self, x):\n        # Take the input and apply average and max 
pooling\n        avg_values = self.avg_pool(x)\n        max_values = self.max_pool(x)\n        out = self.MLP(avg_values) + self.MLP(max_values)\n        scale = x * torch.sigmoid(out).unsqueeze(2).unsqueeze(3).expand_as(x)\n        return scale\n\n\nclass SpatialAttention(nn.Module):\n    def __init__(self, kernel_size=7):\n        super().__init__()\n        assert kernel_size in (3, 7), \"kernel size must be 3 or 7\"\n        padding = 3 if kernel_size == 7 else 1\n        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)\n        self.bn = nn.BatchNorm2d(1)\n\n    def forward(self, x):\n        avg_out = torch.mean(x, dim=1, keepdim=True)\n        max_out, _ = torch.max(x, dim=1, keepdim=True)\n        out = torch.cat([avg_out, max_out], dim=1)\n        out = self.conv(out)\n        out = self.bn(out)\n        scale = x * torch.sigmoid(out)\n        return scale\n\n\nclass CBAM(nn.Module):\n    def __init__(self, input_channels, reduction_ratio=16, kernel_size=7):\n        super().__init__()\n        self.channel_att = ChannelAttention(input_channels, reduction_ratio=reduction_ratio)\n        self.spatial_att = SpatialAttention(kernel_size=kernel_size)\n\n    def forward(self, x):\n        out = self.channel_att(x)\n        out = self.spatial_att(out)\n        return out\n"
  },
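`SpaceToDepth` and `DepthToSpace` in `models/layers.py` only rearrange pixels, so their output shapes follow directly from the block size. A small sketch of the shape arithmetic (helper names are illustrative):

```python
def space_to_depth_shape(n, c, h, w, block_size):
    """Output shape of SpaceToDepth: spatial dims shrink, channels grow by block_size^2."""
    assert h % block_size == 0 and w % block_size == 0, "H and W must be divisible by block_size"
    return (n, c * block_size**2, h // block_size, w // block_size)


def depth_to_space_shape(n, c, h, w, block_size):
    """Output shape of DepthToSpace, the inverse rearrangement."""
    assert c % block_size**2 == 0, "C must be divisible by block_size^2"
    return (n, c // block_size**2, h * block_size, w * block_size)
```

Applying one mapping and then the other returns the original shape, which is the sense in which the two modules are inverses.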
  {
    "path": "models/regression_lightning.py",
    "content": "import lightning.pytorch as pl\nimport torch\nfrom torch import nn, optim, Tensor\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.sampler import SubsetRandomSampler\nfrom utils import dataset_precip\nimport argparse\nimport numpy as np\nfrom metric.precipitation_metrics import PrecipitationMetrics\nfrom utils.formatting import make_metrics_str\n\nclass UNetBase(pl.LightningModule):\n    @staticmethod\n    def add_model_specific_args(parent_parser):\n        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)\n        parser.add_argument(\n            \"--model\",\n            type=str,\n            default=\"UNet\",\n            choices=[\"UNet\", \"UNetDS\", \"UNetAttention\", \"UNetDSAttention\", \"PersistenceModel\"],\n        )\n        parser.add_argument(\"--n_channels\", type=int, default=12)\n        parser.add_argument(\"--n_classes\", type=int, default=1)\n        parser.add_argument(\"--kernels_per_layer\", type=int, default=1)\n        parser.add_argument(\"--bilinear\", type=bool, default=True)\n        parser.add_argument(\"--reduction_ratio\", type=int, default=16)\n        parser.add_argument(\"--lr_patience\", type=int, default=5)\n        parser.add_argument(\"--threshold\", type=float, default=0.5)\n        return parser\n\n    def __init__(self, hparams):\n        super().__init__()\n        self.save_hyperparameters(hparams)\n        self.train_metrics = PrecipitationMetrics(\n            threshold=self.hparams.threshold if hasattr(self.hparams, 'threshold') else 0.5\n        )\n        self.val_metrics = PrecipitationMetrics(\n            threshold=self.hparams.threshold if hasattr(self.hparams, 'threshold') else 0.5\n        )\n        self.test_metrics = PrecipitationMetrics(\n            threshold=self.hparams.threshold if hasattr(self.hparams, 'threshold') else 0.5\n        )\n\n    def forward(self, x):\n        pass\n\n    def configure_optimizers(self):\n        opt = 
optim.Adam(self.parameters(), lr=self.hparams.learning_rate)\n        scheduler = {\n            \"scheduler\": optim.lr_scheduler.ReduceLROnPlateau(\n                opt, mode=\"min\", factor=0.1, patience=self.hparams.lr_patience\n            ),\n            \"monitor\": \"val_loss\",  # Default: val_loss\n        }\n        return [opt], [scheduler]\n\n    def loss_func(self, y_pred, y_true):\n        # Ensure consistent shapes before computing loss\n        if y_pred.dim() > y_true.dim():\n            y_pred = y_pred.squeeze(1)\n        elif y_true.dim() > y_pred.dim():\n            y_pred = y_pred.unsqueeze(1)\n\n        # reduction=\"mean\" is average of every pixel, but I want average of image\n        return nn.functional.mse_loss(y_pred, y_true, reduction=\"sum\") / y_true.size(0)\n\n    def training_step(self, batch, batch_idx):\n        x, y = batch\n        y_pred = self(x)\n        loss = self.loss_func(y_pred, y)\n\n        self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n\n        # Update training metrics with detached tensors\n        self.train_metrics.update(y_pred.detach(), y.detach())\n        \n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        x, y = batch\n        y_pred = self(x)\n        loss = self.loss_func(y_pred, y)\n        self.log(\"val_loss\", loss, prog_bar=True)\n        \n        # Update validation metrics with detached tensors\n        self.val_metrics.update(y_pred.detach(), y.detach())\n\n    def test_step(self, batch, batch_idx):\n        \"\"\"Calculate the loss (MSE per default) and other metrics on the test set normalized and denormalized.\"\"\"\n        x, y = batch\n        y_pred = self(x)\n\n        # Update the test metrics\n        self.test_metrics.update(y_pred.detach(), y.detach())\n    \n    def on_test_epoch_end(self):\n        \"\"\"Compute and log all metrics at the end of the test epoch.\"\"\"\n        test_metrics_dict = 
self.test_metrics.compute()\n        \n        # Print all metrics in one line\n        print(f\"\\n\\nEpoch {self.current_epoch} - Test Metrics: {make_metrics_str(test_metrics_dict)}\")\n        \n        # Reset the metrics for the next test epoch\n        self.test_metrics.reset()\n\n    def on_train_epoch_end(self):\n        \"\"\"Compute and log all metrics at the end of the training epoch.\"\"\"\n        train_metrics_dict = self.train_metrics.compute()\n        \n        print(f\"\\n\\nEpoch {self.current_epoch} - Train Metrics: {make_metrics_str(train_metrics_dict)}\")\n        self.train_metrics.reset()\n\n    def on_validation_epoch_end(self):\n        \"\"\"Compute and log all metrics at the end of the validation epoch.\"\"\"\n        val_metrics_dict = self.val_metrics.compute()\n        \n        print(f\"\\n\\nEpoch {self.current_epoch} - Validation Metrics: {make_metrics_str(val_metrics_dict)}\")\n        self.val_metrics.reset()\n\n\nclass PrecipRegressionBase(UNetBase):\n    @staticmethod\n    def add_model_specific_args(parent_parser):\n        parent_parser = UNetBase.add_model_specific_args(parent_parser)\n        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)\n        parser.add_argument(\"--num_input_images\", type=int, default=12)\n        parser.add_argument(\"--num_output_images\", type=int, default=6)\n        parser.add_argument(\"--valid_size\", type=float, default=0.1)\n        parser.add_argument(\"--use_oversampled_dataset\", type=bool, default=True)\n        parser.n_channels = parser.parse_args().num_input_images\n        parser.n_classes = 1\n        return parser\n\n    def __init__(self, hparams):\n        super().__init__(hparams=hparams)\n        self.train_dataset = None\n        self.valid_dataset = None\n        self.train_sampler = None\n        self.valid_sampler = None\n\n    def prepare_data(self):\n        # train_transform = transforms.Compose([\n        #     
transforms.RandomHorizontalFlip()]\n        # )\n        train_transform = None\n        valid_transform = None\n        precip_dataset = (\n            dataset_precip.precipitation_maps_oversampled_h5\n            if self.hparams.use_oversampled_dataset\n            else dataset_precip.precipitation_maps_h5\n        )\n        self.train_dataset = precip_dataset(\n            in_file=self.hparams.dataset_folder,\n            num_input_images=self.hparams.num_input_images,\n            num_output_images=self.hparams.num_output_images,\n            train=True,\n            transform=train_transform,\n        )\n        self.valid_dataset = precip_dataset(\n            in_file=self.hparams.dataset_folder,\n            num_input_images=self.hparams.num_input_images,\n            num_output_images=self.hparams.num_output_images,\n            train=True,\n            transform=valid_transform,\n        )\n\n        num_train = len(self.train_dataset)\n        indices = list(range(num_train))\n        split = int(np.floor(self.hparams.valid_size * num_train))\n\n        np.random.shuffle(indices)\n\n        train_idx, valid_idx = indices[split:], indices[:split]\n        self.train_sampler = SubsetRandomSampler(train_idx)\n        self.valid_sampler = SubsetRandomSampler(valid_idx)\n\n    def train_dataloader(self):\n        train_loader = DataLoader(\n            self.train_dataset,\n            batch_size=self.hparams.batch_size,\n            sampler=self.train_sampler,\n            pin_memory=True,\n            # The following can/should be tweaked depending on the number of CPU cores\n            num_workers=1,\n            persistent_workers=True,\n        )\n        return train_loader\n\n    def val_dataloader(self):\n        valid_loader = DataLoader(\n            self.valid_dataset,\n            batch_size=self.hparams.batch_size,\n            sampler=self.valid_sampler,\n            pin_memory=True,\n            # The following can/should be tweaked depending 
on the number of CPU cores\n            num_workers=1,\n            persistent_workers=True,\n        )\n        return valid_loader\n\n\nclass PersistenceModel(UNetBase):\n    def forward(self, x):\n        return x[:, -1:, :, :]\n"
  },
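`PrecipRegressionBase.prepare_data` carves a validation subset out of the training file by shuffling indices and splitting off the first `valid_size` fraction. The same logic as a self-contained sketch (using the stdlib `random` module in place of numpy; the function name is illustrative):

```python
import math
import random


def split_indices(num_train, valid_size, seed=None):
    """Shuffle dataset indices and split off the first `valid_size` fraction
    for validation, mirroring the sampler setup in prepare_data()."""
    indices = list(range(num_train))
    random.Random(seed).shuffle(indices)
    split = int(math.floor(valid_size * num_train))
    # Everything after the split point trains; the head becomes validation
    return indices[split:], indices[:split]
```

With `num_train=100` and `valid_size=0.1` this yields disjoint sets of 90 training and 10 validation indices that together cover the whole dataset.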
  {
    "path": "models/unet_parts.py",
    "content": "\"\"\" Parts of the U-Net model \"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass DoubleConv(nn.Module):\n    \"\"\"(convolution => [BN] => ReLU) * 2\"\"\"\n\n    def __init__(self, in_channels, out_channels, mid_channels=None):\n        super().__init__()\n        if not mid_channels:\n            mid_channels = out_channels\n        self.double_conv = nn.Sequential(\n            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),\n            nn.BatchNorm2d(mid_channels),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),\n            nn.BatchNorm2d(out_channels),\n            nn.ReLU(inplace=True),\n        )\n\n    def forward(self, x):\n        return self.double_conv(x)\n\n\nclass Down(nn.Module):\n    \"\"\"Downscaling with maxpool then double conv\"\"\"\n\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))\n\n    def forward(self, x):\n        return self.maxpool_conv(x)\n\n\nclass Up(nn.Module):\n    \"\"\"Upscaling then double conv\"\"\"\n\n    def __init__(self, in_channels, out_channels, bilinear=True):\n        super().__init__()\n\n        # if bilinear, use the normal convolutions to reduce the number of channels\n        if bilinear:\n            self.up = nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=True)\n            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)\n        else:\n            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)\n            self.conv = DoubleConv(in_channels, out_channels)\n\n    def forward(self, x1, x2):\n        x1 = self.up(x1)\n        # input is CHW\n        diffY = x2.size()[2] - x1.size()[2]\n        diffX = x2.size()[3] - x1.size()[3]\n\n        x1 = F.pad(x1, [diffX // 2, diffX - diffX 
// 2, diffY // 2, diffY - diffY // 2])\n        # if you have padding issues, see\n        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a\n        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd\n        x = torch.cat([x2, x1], dim=1)\n        return self.conv(x)\n\n\nclass OutConv(nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)\n\n    def forward(self, x):\n        return self.conv(x)\n"
  },
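In `Up.forward`, the upsampled tensor `x1` is padded so it matches the skip-connection tensor `x2` before concatenation, splitting any odd size difference between the two sides. The padding arithmetic in isolation (a sketch; `skip_pad` is not a repo function, and shapes are passed as `(H, W)` tuples):

```python
def skip_pad(x1_hw, x2_hw):
    """Return the [left, right, top, bottom] padding that Up.forward applies
    to align the upsampled tensor x1 with the skip tensor x2."""
    diff_y = x2_hw[0] - x1_hw[0]
    diff_x = x2_hw[1] - x1_hw[1]
    # Odd differences put the extra pixel on the right/bottom side
    return [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2]
```

For example, aligning a 28x28 upsampled map with a 29x29 skip map pads one pixel on the right and bottom only; equal shapes produce all-zero padding.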
  {
    "path": "models/unet_parts_depthwise_separable.py",
    "content": "\"\"\" Parts of the U-Net model \"\"\"\n\n# Base model taken from: https://github.com/milesial/Pytorch-UNet\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom models.layers import DepthwiseSeparableConv\n\n\nclass DoubleConvDS(nn.Module):\n    \"\"\"(convolution => [BN] => ReLU) * 2\"\"\"\n\n    def __init__(self, in_channels, out_channels, mid_channels=None, kernels_per_layer=1):\n        super().__init__()\n        if not mid_channels:\n            mid_channels = out_channels\n        self.double_conv = nn.Sequential(\n            DepthwiseSeparableConv(\n                in_channels,\n                mid_channels,\n                kernel_size=3,\n                kernels_per_layer=kernels_per_layer,\n                padding=1,\n            ),\n            nn.BatchNorm2d(mid_channels),\n            nn.ReLU(inplace=True),\n            DepthwiseSeparableConv(\n                mid_channels,\n                out_channels,\n                kernel_size=3,\n                kernels_per_layer=kernels_per_layer,\n                padding=1,\n            ),\n            nn.BatchNorm2d(out_channels),\n            nn.ReLU(inplace=True),\n        )\n\n    def forward(self, x):\n        return self.double_conv(x)\n\n\nclass DownDS(nn.Module):\n    \"\"\"Downscaling with maxpool then double conv\"\"\"\n\n    def __init__(self, in_channels, out_channels, kernels_per_layer=1):\n        super().__init__()\n        self.maxpool_conv = nn.Sequential(\n            nn.MaxPool2d(2),\n            DoubleConvDS(in_channels, out_channels, kernels_per_layer=kernels_per_layer),\n        )\n\n    def forward(self, x):\n        return self.maxpool_conv(x)\n\n\nclass UpDS(nn.Module):\n    \"\"\"Upscaling then double conv\"\"\"\n\n    def __init__(self, in_channels, out_channels, bilinear=True, kernels_per_layer=1):\n        super().__init__()\n\n        # if bilinear, use the normal convolutions to reduce the number of channels\n        if bilinear:\n        
    self.up = nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=True)\n            self.conv = DoubleConvDS(\n                in_channels,\n                out_channels,\n                in_channels // 2,\n                kernels_per_layer=kernels_per_layer,\n            )\n        else:\n            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)\n            self.conv = DoubleConvDS(in_channels, out_channels, kernels_per_layer=kernels_per_layer)\n\n    def forward(self, x1, x2):\n        x1 = self.up(x1)\n        # input is CHW\n        diffY = x2.size()[2] - x1.size()[2]\n        diffX = x2.size()[3] - x1.size()[3]\n\n        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])\n        # if you have padding issues, see\n        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a\n        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd\n        x = torch.cat([x2, x1], dim=1)\n        return self.conv(x)\n\n\nclass OutConv(nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)\n\n    def forward(self, x):\n        return self.conv(x)\n"
  },
  {
    "path": "models/unet_precip_regression_lightning.py",
    "content": "from models.unet_parts import Down, DoubleConv, Up, OutConv\nfrom models.unet_parts_depthwise_separable import DoubleConvDS, UpDS, DownDS\nfrom models.layers import CBAM\nfrom models.regression_lightning import PrecipRegressionBase, PersistenceModel\n\n\nclass UNet(PrecipRegressionBase):\n    def __init__(self, hparams):\n        super().__init__(hparams=hparams)\n        self.n_channels = self.hparams.n_channels\n        self.n_classes = self.hparams.n_classes\n        self.bilinear = self.hparams.bilinear\n\n        self.inc = DoubleConv(self.n_channels, 64)\n        self.down1 = Down(64, 128)\n        self.down2 = Down(128, 256)\n        self.down3 = Down(256, 512)\n        factor = 2 if self.bilinear else 1\n        self.down4 = Down(512, 1024 // factor)\n        self.up1 = Up(1024, 512 // factor, self.bilinear)\n        self.up2 = Up(512, 256 // factor, self.bilinear)\n        self.up3 = Up(256, 128 // factor, self.bilinear)\n        self.up4 = Up(128, 64, self.bilinear)\n\n        self.outc = OutConv(64, self.n_classes)\n\n    def forward(self, x):\n        x1 = self.inc(x)\n        x2 = self.down1(x1)\n        x3 = self.down2(x2)\n        x4 = self.down3(x3)\n        x5 = self.down4(x4)\n        x = self.up1(x5, x4)\n        x = self.up2(x, x3)\n        x = self.up3(x, x2)\n        x = self.up4(x, x1)\n        logits = self.outc(x)\n        return logits\n\n\nclass UNetAttention(PrecipRegressionBase):\n    def __init__(self, hparams):\n        super().__init__(hparams=hparams)\n        self.n_channels = self.hparams.n_channels\n        self.n_classes = self.hparams.n_classes\n        self.bilinear = self.hparams.bilinear\n        reduction_ratio = self.hparams.reduction_ratio\n\n        self.inc = DoubleConv(self.n_channels, 64)\n        self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)\n        self.down1 = Down(64, 128)\n        self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)\n        self.down2 = Down(128, 256)\n        
self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)\n        self.down3 = Down(256, 512)\n        self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)\n        factor = 2 if self.bilinear else 1\n        self.down4 = Down(512, 1024 // factor)\n        self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)\n        self.up1 = Up(1024, 512 // factor, self.bilinear)\n        self.up2 = Up(512, 256 // factor, self.bilinear)\n        self.up3 = Up(256, 128 // factor, self.bilinear)\n        self.up4 = Up(128, 64, self.bilinear)\n\n        self.outc = OutConv(64, self.n_classes)\n\n    def forward(self, x):\n        x1 = self.inc(x)\n        x1Att = self.cbam1(x1)\n        x2 = self.down1(x1)\n        x2Att = self.cbam2(x2)\n        x3 = self.down2(x2)\n        x3Att = self.cbam3(x3)\n        x4 = self.down3(x3)\n        x4Att = self.cbam4(x4)\n        x5 = self.down4(x4)\n        x5Att = self.cbam5(x5)\n        x = self.up1(x5Att, x4Att)\n        x = self.up2(x, x3Att)\n        x = self.up3(x, x2Att)\n        x = self.up4(x, x1Att)\n        logits = self.outc(x)\n        return logits\n\n\nclass UNetDS(PrecipRegressionBase):\n    def __init__(self, hparams):\n        super().__init__(hparams=hparams)\n        self.n_channels = self.hparams.n_channels\n        self.n_classes = self.hparams.n_classes\n        self.bilinear = self.hparams.bilinear\n        kernels_per_layer = self.hparams.kernels_per_layer\n\n        self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)\n        self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)\n        self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)\n        self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)\n        factor = 2 if self.bilinear else 1\n        self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)\n        self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)\n        
self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)\n        self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)\n        self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)\n\n        self.outc = OutConv(64, self.n_classes)\n\n    def forward(self, x):\n        x1 = self.inc(x)\n        x2 = self.down1(x1)\n        x3 = self.down2(x2)\n        x4 = self.down3(x3)\n        x5 = self.down4(x4)\n        x = self.up1(x5, x4)\n        x = self.up2(x, x3)\n        x = self.up3(x, x2)\n        x = self.up4(x, x1)\n        logits = self.outc(x)\n        return logits\n\n\nclass UNetDSAttention(PrecipRegressionBase):\n    def __init__(self, hparams):\n        super().__init__(hparams=hparams)\n        self.n_channels = self.hparams.n_channels\n        self.n_classes = self.hparams.n_classes\n        self.bilinear = self.hparams.bilinear\n        reduction_ratio = self.hparams.reduction_ratio\n        kernels_per_layer = self.hparams.kernels_per_layer\n\n        self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)\n        self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)\n        self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)\n        self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)\n        self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)\n        self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)\n        self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)\n        self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)\n        factor = 2 if self.bilinear else 1\n        self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)\n        self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)\n        self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)\n        self.up2 = UpDS(512, 256 // factor, self.bilinear, 
kernels_per_layer=kernels_per_layer)\n        self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)\n        self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)\n\n        self.outc = OutConv(64, self.n_classes)\n\n    def forward(self, x):\n        x1 = self.inc(x)\n        x1Att = self.cbam1(x1)\n        x2 = self.down1(x1)\n        x2Att = self.cbam2(x2)\n        x3 = self.down2(x2)\n        x3Att = self.cbam3(x3)\n        x4 = self.down3(x3)\n        x4Att = self.cbam4(x4)\n        x5 = self.down4(x4)\n        x5Att = self.cbam5(x5)\n        x = self.up1(x5Att, x4Att)\n        x = self.up2(x, x3Att)\n        x = self.up3(x, x2Att)\n        x = self.up4(x, x1Att)\n        logits = self.outc(x)\n        return logits\n\n\nclass UNetDSAttention4CBAMs(PrecipRegressionBase):\n    def __init__(self, hparams):\n        super().__init__(hparams=hparams)\n        self.n_channels = self.hparams.n_channels\n        self.n_classes = self.hparams.n_classes\n        self.bilinear = self.hparams.bilinear\n        reduction_ratio = self.hparams.reduction_ratio\n        kernels_per_layer = self.hparams.kernels_per_layer\n\n        self.inc = DoubleConvDS(self.n_channels, 64, kernels_per_layer=kernels_per_layer)\n        self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)\n        self.down1 = DownDS(64, 128, kernels_per_layer=kernels_per_layer)\n        self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)\n        self.down2 = DownDS(128, 256, kernels_per_layer=kernels_per_layer)\n        self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)\n        self.down3 = DownDS(256, 512, kernels_per_layer=kernels_per_layer)\n        self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)\n        factor = 2 if self.bilinear else 1\n        self.down4 = DownDS(512, 1024 // factor, kernels_per_layer=kernels_per_layer)\n        self.up1 = UpDS(1024, 512 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)\n       
 self.up2 = UpDS(512, 256 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)\n        self.up3 = UpDS(256, 128 // factor, self.bilinear, kernels_per_layer=kernels_per_layer)\n        self.up4 = UpDS(128, 64, self.bilinear, kernels_per_layer=kernels_per_layer)\n\n        self.outc = OutConv(64, self.n_classes)\n\n    def forward(self, x):\n        x1 = self.inc(x)\n        x1Att = self.cbam1(x1)\n        x2 = self.down1(x1)\n        x2Att = self.cbam2(x2)\n        x3 = self.down2(x2)\n        x3Att = self.cbam3(x3)\n        x4 = self.down3(x3)\n        x4Att = self.cbam4(x4)\n        x5 = self.down4(x4)\n        x = self.up1(x5, x4Att)\n        x = self.up2(x, x3Att)\n        x = self.up3(x, x2Att)\n        x = self.up4(x, x1Att)\n        logits = self.outc(x)\n        return logits\n"
  },
  {
    "path": "plot_examples.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"from torch import nn\\n\",\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import numpy as np\\n\",\n    \"import os\\n\",\n    \"import models.unet_precip_regression_lightning as unet_regr\\n\",\n    \"from utils import dataset_precip\\n\",\n    \"\\n\",\n    \"dataset = dataset_precip.precipitation_maps_oversampled_h5(\\n\",\n    \"            in_file=\\\"data/precipitation/train_test_2016-2019_input-length_12_img-ahead_6_rain-threshold_50.h5\\\",\\n\",\n    \"            num_input_images=12,\\n\",\n    \"            num_output_images=6, train=False)\\n\",\n    \"# 222, 444, 777, 1337 is fine\\n\",\n    \"x, y_true = dataset[777]\\n\",\n    \"# add batch dimension\\n\",\n    \"x = torch.tensor(x).unsqueeze(0)\\n\",\n    \"y_true = torch.tensor(y_true).unsqueeze(0)\\n\",\n    \"print(x.shape, y_true.shape)\\n\",\n    \"\\n\",\n    \"models_dir = \\\"checkpoints/comparison\\\"\\n\",\n    \"precip_models = [m for m in os.listdir(models_dir) if \\\".ckpt\\\" in m]\\n\",\n    \"print(precip_models)\\n\",\n    \"\\n\",\n    \"loss_func = nn.MSELoss(reduction=\\\"sum\\\")\\n\",\n    \"print(\\\"Persistence\\\", loss_func(x.squeeze()[-1]*47.83, y_true.squeeze()*47.83).item())\\n\",\n    \"\\n\",\n    \"plot_images = dict()\\n\",\n    \"for i, model_file in enumerate(precip_models):\\n\",\n    \"    # model_name = \\\"\\\".join(model_file.split(\\\"_\\\")[0])\\n\",\n    \"    if \\\"UNetAttention\\\" in model_file:\\n\",\n    \"        model_name = \\\"UNet with CBAM\\\"\\n\",\n    \"        model = unet_regr.UNetAttention\\n\",\n    \"    elif \\\"UNetDSAttention4CBAMs\\\" in model_file:\\n\",\n    \"        model_name = \\\"UNetDS Attention 4CBAMs\\\"\\n\",\n    \"        model = unet_regr.UNetDSAttention4CBAMs\\n\",\n    \"    elif 
\\\"UNetDSAttention\\\" in model_file:\\n\",\n    \"        model_name = \\\"SmaAt-UNet\\\"\\n\",\n    \"        model = unet_regr.UNetDSAttention\\n\",\n    \"    elif \\\"UNetDS\\\" in model_file:\\n\",\n    \"        model_name = \\\"UNet with DSC\\\"\\n\",\n    \"        model = unet_regr.UNetDS\\n\",\n    \"    elif \\\"UNet\\\" in model_file:\\n\",\n    \"        model_name = \\\"UNet\\\"\\n\",\n    \"        model = unet_regr.UNet\\n\",\n    \"    else:\\n\",\n    \"        raise NotImplementedError(f\\\"Model not found\\\")\\n\",\n    \"    model = model.load_from_checkpoint(f\\\"{models_dir}/{model_file}\\\")\\n\",\n    \"    # loaded = torch.load(f\\\"{models_dir}/{model_file}\\\")\\n\",\n    \"    # model = loaded[\\\"model\\\"]\\n\",\n    \"    # model.load_state_dict(loaded[\\\"state_dict\\\"])\\n\",\n    \"    # loss_func = nn.CrossEntropyLoss(weight=torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.]))\\n\",\n    \"\\n\",\n    \"    model.to(\\\"cpu\\\")\\n\",\n    \"    model.eval()\\n\",\n    \"\\n\",\n    \"    # get prediction of model\\n\",\n    \"    with torch.no_grad():\\n\",\n    \"        y_pred = model(x)\\n\",\n    \"    y_pred = y_pred.squeeze()\\n\",\n    \"    print(model_name, loss_func(y_pred*47.83, y_true.squeeze()*47.83).item())\\n\",\n    \"\\n\",\n    \"    plot_images[model_name] = y_pred\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": false\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"#### Plotting ####\\n\",\n    \"plt.figure(figsize=(10, 20))\\n\",\n    \"plt.subplot(len(precip_models)+2, 1, 1)\\n\",\n    \"plt.imshow(y_true.squeeze()*47.83)\\n\",\n    \"plt.colorbar()\\n\",\n    \"plt.title(\\\"Ground truth\\\")\\n\",\n    \"plt.axis(\\\"off\\\")\\n\",\n    \"\\n\",\n    \"plt.subplot(len(precip_models)+2, 1, 2)\\n\",\n    \"plt.imshow(x.squeeze()[-1]*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\\n\",\n    \"plt.colorbar()\\n\",\n  
  \"plt.title(\\\"Persistence\\\")\\n\",\n    \"plt.axis(\\\"off\\\")\\n\",\n    \"\\n\",\n    \"for i, (model_name, y_pred) in enumerate(plot_images.items()):\\n\",\n    \"    print(model_name)\\n\",\n    \"    plt.subplot(len(precip_models)+2, 1, i+3)\\n\",\n    \"    plt.imshow(y_pred*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\\n\",\n    \"    plt.colorbar()\\n\",\n    \"    plt.title(model_name)\\n\",\n    \"    plt.axis(\\\"off\\\")\\n\",\n    \"# plt.colorbar()\\n\",\n    \"plt.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": false\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"fig, axesgrid = plt.subplots(nrows=2, ncols=(len(precip_models)+2)//2, figsize=(20, 10))\\n\",\n    \"axes = axesgrid.flat\\n\",\n    \"im = axes[0].imshow(y_true.squeeze()*47.83)\\n\",\n    \"axes[0].set_title(\\\"Groundtruth\\\", fontsize=15)\\n\",\n    \"axes[0].set_axis_off()\\n\",\n    \"axes[1].imshow(x.squeeze()[-1]*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\\n\",\n    \"axes[1].set_title(\\\"Persistence\\\", fontsize=15)\\n\",\n    \"axes[1].set_axis_off()\\n\",\n    \"\\n\",\n    \"# for ax in axes.flat[2:]:\\n\",\n    \"for i, (model_name, y_pred) in enumerate(plot_images.items()):\\n\",\n    \"    print(model_name)\\n\",\n    \"#     plt.subplot(len(precip_models)+2, 1, i+3)\\n\",\n    \"    axes[i+2].imshow(y_pred*47.83, vmax=torch.max(y_true*47.83), vmin=torch.min(y_true*47.83))\\n\",\n    \"    axes[i+2].set_title(model_name, fontsize=15)\\n\",\n    \"    axes[i+2].set_axis_off()\\n\",\n    \"#     im = ax.imshow(np.random.random((16, 16)), cmap='viridis',\\n\",\n    \"#                    vmin=0, vmax=1)\\n\",\n    \"\\n\",\n    \"output_dir = \\\"plots\\\"\\n\",\n    \"os.makedirs(output_dir, exist_ok=True)\\n\",\n    \"\\n\",\n    \"fig.subplots_adjust(wspace=0.02)\\n\",\n    \"cbar = fig.colorbar(im, ax=axesgrid.ravel().tolist(), 
shrink=1)\\n\",\n    \"cbar.ax.set_ylabel('mm/5min', fontsize=15)\\n\",\n    \"plt.savefig(os.path.join(output_dir, \\\"Precip_example.eps\\\"))\\n\",\n    \"plt.savefig(os.path.join(output_dir, \\\"Precip_example.png\\\"), dpi=300)\\n\",\n    \"plt.show()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \".venv\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 2\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython2\",\n   \"version\": \"3.10.11\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"smaat-unet\"\nversion = \"0.1.0\"\ndescription = \"Code for the paper `SmaAt-UNet: Precipitation Nowcasting using a Small Attention-UNet Architecture`\"\nauthors = [{ name = \"Kevin Trebing\", email = \"Kevin.Trebing@gmx.net\" }]\nreadme = \"README.md\"\nrequires-python = \">=3.10.11\"\ndependencies = [\n    \"h5py>=3.13.0\",\n    \"lightning>=2.5.0.post0\",\n    \"matplotlib>=3.10.0\",\n    \"numpy>=2.2.3\",\n    \"tensorboard>=2.19.0\",\n    \"torch>=2.6.0\",\n    \"torchsummary>=1.5.1\",\n    \"tqdm>=4.67.1\",\n]\n\n[dependency-groups]\ndev = [\n    \"mypy>=1.15.0\",\n    \"pre-commit>=4.1.0\",\n    \"ruff>=0.9.7\",\n]\nscripts = [\n    \"ipykernel>=6.29.5\",\n]\n\n[[tool.uv.index]]\nname = \"pytorch-cu126\"\nurl = \"https://download.pytorch.org/whl/cu126\"\nexplicit = true\n\n[tool.uv.sources]\ntorch = [\n  { index = \"pytorch-cu126\", marker = \"sys_platform == 'linux' or sys_platform == 'win32'\" },\n]\ntorchvision = [\n  { index = \"pytorch-cu126\", marker = \"sys_platform == 'linux' or sys_platform == 'win32'\" },\n]\n\n[tool.mypy]\npython_version = \"3.10\"\nignore_missing_imports = true\n\n[tool.ruff]\nline-length = 120\n\n[tool.ruff.lint]\nselect = [\n    \"E\",   # pycodestyle Errors\n    \"W\",   # pycodestyle Warnings\n    \"F\",   # pyflakes\n    \"UP\",  # pyupgrade\n#    \"D\",   # pydocstyle\n    \"B\",   # flake8-bugbear\n    \"C4\",  # flake8-comprehensions\n    \"A\",   # flake8-builtins\n    \"C90\", # mccabe complexity\n    \"DJ\",  # flake8-django\n    \"PIE\", # flake8-pie\n#    \"SIM\", # flake8-simplify\n]\nignore = [\n    \"B905\", # length-checking in zip() only introduced with Python 3.10 (PEP618)\n    \"UP007\",  # Python version 3.9 does not allow writing union types as X | Y\n    \"UP038\",  # Python version 3.9 does not allow writing union types as X | Y\n    \"D202\",  # No blank lines allowed after function docstring\n    \"D100\",  # Missing docstring in public module\n    \"D104\",  # Missing 
docstring in public package\n    \"D205\",  # 1 blank line required between summary line and description\n    \"D203\",  # One blank line before class docstring\n    \"D212\",  # Multi-line docstring summary should start at the first line\n    \"D213\",  # Multi-line docstring summary should start at the second line\n    \"D400\",  # First line of docstring should end with a period\n    \"D404\",  # First word of the docstring should not be \"This\"\n    \"D415\",  # First line should end with a period, question mark, or exclamation point\n    \"DJ001\", # Avoid using `null=True` on string-based fields\n]\nexclude = [\n    \".git\",\n    \".local\",\n    \".cache\",\n    \".venv\",\n    \"./venv\",\n    \".vscode\",\n    \"__pycache__\",\n    \"docs\",\n    \"build\",\n    \"dist\",\n    \"notebooks\",\n    \"migrations\"\n]\n\n[tool.ruff.lint.mccabe]\n# Flag errors (`C901`) whenever the complexity level exceeds 12.\nmax-complexity = 12\n"
  },
  {
    "path": "requirements-dev.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv export --format=requirements-txt --group dev --no-hashes --no-sources --output-file=requirements-dev.txt\nabsl-py==2.1.0\naiohappyeyeballs==2.4.6\naiohttp==3.11.12\naiosignal==1.3.2\nasync-timeout==5.0.1 ; python_full_version < '3.11'\nattrs==25.1.0\ncfgv==3.4.0\ncolorama==0.4.6 ; sys_platform == 'win32'\ncontourpy==1.3.1\ncycler==0.12.1\ndistlib==0.3.9\nfilelock==3.17.0\nfonttools==4.56.0\nfrozenlist==1.5.0\nfsspec==2025.2.0\ngrpcio==1.70.0\nh5py==3.13.0\nidentify==2.6.7\nidna==3.10\njinja2==3.1.5\nkiwisolver==1.4.8\nlightning==2.5.0.post0\nlightning-utilities==0.12.0\nmarkdown==3.7\nmarkupsafe==3.0.2\nmatplotlib==3.10.0\nmpmath==1.3.0\nmultidict==6.1.0\nmypy==1.15.0\nmypy-extensions==1.0.0\nnetworkx==3.4.2\nnodeenv==1.9.1\nnumpy==2.2.3\nnvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform 
== 'linux'\npackaging==24.2\npillow==11.1.0\nplatformdirs==4.3.6\npre-commit==4.1.0\npropcache==0.3.0\nprotobuf==5.29.3\npyparsing==3.2.1\npython-dateutil==2.9.0.post0\npytorch-lightning==2.5.0.post0\npyyaml==6.0.2\nruff==0.9.7\nsetuptools==75.8.0\nsix==1.17.0\nsympy==1.13.1\ntensorboard==2.19.0\ntensorboard-data-server==0.7.2\ntomli==2.2.1 ; python_full_version < '3.11'\ntorch==2.6.0\ntorchmetrics==1.6.1\ntorchsummary==1.5.1\ntqdm==4.67.1\ntriton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'\ntyping-extensions==4.12.2\nvirtualenv==20.29.2\nwerkzeug==3.1.3\nyarl==1.18.3\n"
  },
  {
    "path": "requirements-full.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv export --format=requirements-txt --all-groups --no-hashes --no-sources --output-file=requirements-full.txt\nabsl-py==2.1.0\naiohappyeyeballs==2.4.6\naiohttp==3.11.12\naiosignal==1.3.2\nappnope==0.1.4 ; sys_platform == 'darwin'\nasttokens==3.0.0\nasync-timeout==5.0.1 ; python_full_version < '3.11'\nattrs==25.1.0\ncffi==1.17.1 ; implementation_name == 'pypy'\ncfgv==3.4.0\ncolorama==0.4.6 ; sys_platform == 'win32'\ncomm==0.2.2\ncontourpy==1.3.1\ncycler==0.12.1\ndebugpy==1.8.12\ndecorator==5.1.1\ndistlib==0.3.9\nexceptiongroup==1.2.2 ; python_full_version < '3.11'\nexecuting==2.2.0\nfilelock==3.17.0\nfonttools==4.56.0\nfrozenlist==1.5.0\nfsspec==2025.2.0\ngrpcio==1.70.0\nh5py==3.13.0\nidentify==2.6.7\nidna==3.10\nipykernel==6.29.5\nipython==8.32.0\njedi==0.19.2\njinja2==3.1.5\njupyter-client==8.6.3\njupyter-core==5.7.2\nkiwisolver==1.4.8\nlightning==2.5.0.post0\nlightning-utilities==0.12.0\nmarkdown==3.7\nmarkupsafe==3.0.2\nmatplotlib==3.10.0\nmatplotlib-inline==0.1.7\nmpmath==1.3.0\nmultidict==6.1.0\nmypy==1.15.0\nmypy-extensions==1.0.0\nnest-asyncio==1.6.0\nnetworkx==3.4.2\nnodeenv==1.9.1\nnumpy==2.2.3\nnvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' 
and sys_platform == 'linux'\nnvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\npackaging==24.2\nparso==0.8.4\npexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'\npillow==11.1.0\nplatformdirs==4.3.6\npre-commit==4.1.0\nprompt-toolkit==3.0.50\npropcache==0.3.0\nprotobuf==5.29.3\npsutil==7.0.0\nptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'\npure-eval==0.2.3\npycparser==2.22 ; implementation_name == 'pypy'\npygments==2.19.1\npyparsing==3.2.1\npython-dateutil==2.9.0.post0\npytorch-lightning==2.5.0.post0\npywin32==308 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32'\npyyaml==6.0.2\npyzmq==26.2.1\nruff==0.9.7\nsetuptools==75.8.0\nsix==1.17.0\nstack-data==0.6.3\nsympy==1.13.1\ntensorboard==2.19.0\ntensorboard-data-server==0.7.2\ntomli==2.2.1 ; python_full_version < '3.11'\ntorch==2.6.0\ntorchmetrics==1.6.1\ntorchsummary==1.5.1\ntornado==6.4.2\ntqdm==4.67.1\ntraitlets==5.14.3\ntriton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'\ntyping-extensions==4.12.2\nvirtualenv==20.29.2\nwcwidth==0.2.13\nwerkzeug==3.1.3\nyarl==1.18.3\n"
  },
  {
    "path": "requirements.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv export --format=requirements-txt --no-dev --no-hashes --no-sources --output-file=requirements.txt\nabsl-py==2.1.0\naiohappyeyeballs==2.4.6\naiohttp==3.11.12\naiosignal==1.3.2\nasync-timeout==5.0.1 ; python_full_version < '3.11'\nattrs==25.1.0\ncolorama==0.4.6 ; sys_platform == 'win32'\ncontourpy==1.3.1\ncycler==0.12.1\nfilelock==3.17.0\nfonttools==4.56.0\nfrozenlist==1.5.0\nfsspec==2025.2.0\ngrpcio==1.70.0\nh5py==3.13.0\nidna==3.10\njinja2==3.1.5\nkiwisolver==1.4.8\nlightning==2.5.0.post0\nlightning-utilities==0.12.0\nmarkdown==3.7\nmarkupsafe==3.0.2\nmatplotlib==3.10.0\nmpmath==1.3.0\nmultidict==6.1.0\nnetworkx==3.4.2\nnumpy==2.2.3\nnvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'\nnvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 
'linux'\npackaging==24.2\npillow==11.1.0\npropcache==0.3.0\nprotobuf==5.29.3\npyparsing==3.2.1\npython-dateutil==2.9.0.post0\npytorch-lightning==2.5.0.post0\npyyaml==6.0.2\nsetuptools==75.8.0\nsix==1.17.0\nsympy==1.13.1\ntensorboard==2.19.0\ntensorboard-data-server==0.7.2\ntorch==2.6.0\ntorchmetrics==1.6.1\ntorchsummary==1.5.1\ntqdm==4.67.1\ntriton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'\ntyping-extensions==4.12.2\nwerkzeug==3.1.3\nyarl==1.18.3\n"
  },
  {
    "path": "root.py",
    "content": "from pathlib import Path\n\nROOT_DIR = Path(__file__).parent\n"
  },
  {
    "path": "train_SmaAtUNet.py",
    "content": "from typing import Optional\n\nfrom models.SmaAt_UNet import SmaAt_UNet\nimport torch\nfrom torch.utils.data import DataLoader\nfrom torch import optim\nfrom torch import nn\nfrom torchvision import transforms\n\nfrom root import ROOT_DIR\nfrom utils import dataset_VOC\nimport time\nfrom tqdm import tqdm\nfrom metric import iou\nimport os\n\n\ndef get_lr(optimizer):\n    for param_group in optimizer.param_groups:\n        return param_group[\"lr\"]\n\n\ndef fit(\n    epochs,\n    model,\n    loss_func,\n    opt,\n    train_dl,\n    valid_dl,\n    dev=None,\n    save_every: Optional[int] = None,\n    tensorboard: bool = False,\n    earlystopping=None,\n    lr_scheduler=None,\n):\n    writer = None\n    if tensorboard:\n        from torch.utils.tensorboard import SummaryWriter\n\n        writer = SummaryWriter(comment=f\"{model.__class__.__name__}\")\n\n    if dev is None:\n        dev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    start_time = time.time()\n    best_mIoU = -1.0\n    earlystopping_counter = 0\n    for epoch in tqdm(range(epochs), desc=\"Epochs\", leave=True):\n        model.train()\n        train_loss = 0.0\n        for _, (xb, yb) in enumerate(tqdm(train_dl, desc=\"Batches\", leave=False)):\n            # for i, (xb, yb) in enumerate(train_dl):\n            loss = loss_func(model(xb.to(dev)), yb.to(dev))\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n            train_loss += loss.item()\n            # if i > 100:\n            #     break\n        train_loss /= len(train_dl)\n\n        # Reduce learning rate after epoch\n        # scheduler.step()\n\n        # Calc validation loss\n        val_loss = 0.0\n        iou_metric = iou.IoU(21, normalized=False)\n        model.eval()\n        with torch.no_grad():\n            for xb, yb in tqdm(valid_dl, desc=\"Validation\", leave=False):\n                # for xb, yb in valid_dl:\n                y_pred = model(xb.to(dev))\n    
            loss = loss_func(y_pred, yb.to(dev))\n                val_loss += loss.item()\n                # Calculate mean IOU\n                pred_class = torch.argmax(nn.functional.softmax(y_pred, dim=1), dim=1)\n                iou_metric.add(pred_class, target=yb)\n\n            iou_class, mean_iou = iou_metric.value()\n            val_loss /= len(valid_dl)\n\n        # Save the model with the best mean IoU\n        if mean_iou > best_mIoU:\n            os.makedirs(\"checkpoints\", exist_ok=True)\n            torch.save(\n                {\n                    \"model\": model,\n                    \"epoch\": epoch,\n                    \"state_dict\": model.state_dict(),\n                    \"optimizer_state_dict\": opt.state_dict(),\n                    \"val_loss\": val_loss,\n                    \"train_loss\": train_loss,\n                    \"mIOU\": mean_iou,\n                },\n                ROOT_DIR / \"checkpoints\" / f\"best_mIoU_model_{model.__class__.__name__}.pt\",\n            )\n            best_mIoU = mean_iou\n            earlystopping_counter = 0\n\n        else:\n            earlystopping_counter += 1\n            if earlystopping is not None and earlystopping_counter >= earlystopping:\n                print(f\"Stopping early --> mean IoU has not increased over {earlystopping} epochs\")\n                break\n\n        print(\n            f\"Epoch: {epoch:5d}, Time: {(time.time() - start_time) / 60:.3f} min, \"\n            f\"Train_loss: {train_loss:2.10f}, Val_loss: {val_loss:2.10f},\",\n            f\"mIOU: {mean_iou:.10f},\",\n            f\"lr: {get_lr(opt)},\",\n            f\"Early stopping counter: {earlystopping_counter}/{earlystopping}\" if earlystopping is not None else \"\",\n        )\n\n        if writer:\n            # add to tensorboard\n            writer.add_scalar(\"Loss/train\", train_loss, epoch)\n            writer.add_scalar(\"Loss/val\", val_loss, epoch)\n            writer.add_scalar(\"Metric/mIOU\", mean_iou, 
epoch)\n            writer.add_scalar(\"Parameters/learning_rate\", get_lr(opt), epoch)\n        if save_every is not None and epoch % save_every == 0:\n            # save model\n            torch.save(\n                {\n                    \"model\": model,\n                    \"epoch\": epoch,\n                    \"state_dict\": model.state_dict(),\n                    \"optimizer_state_dict\": opt.state_dict(),\n                    # 'scheduler_state_dict': scheduler.state_dict(),\n                    \"val_loss\": val_loss,\n                    \"train_loss\": train_loss,\n                    \"mIOU\": mean_iou,\n                },\n                ROOT_DIR / \"checkpoints\" / f\"model_{model.__class__.__name__}_epoch_{epoch}.pt\",\n            )\n        if lr_scheduler is not None:\n            lr_scheduler.step(mean_iou)\n\n\nif __name__ == \"__main__\":\n    dev = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n    dataset_folder = ROOT_DIR / \"data\" / \"VOCdevkit\"\n    batch_size = 8\n    learning_rate = 0.001\n    epochs = 200\n    earlystopping = 30\n    save_every = 1\n\n    # Load your dataset here\n    transformations = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])\n    voc_dataset_train = dataset_VOC.VOCSegmentation(\n        root=dataset_folder,\n        image_set=\"train\",\n        transformations=transformations,\n        augmentations=True,\n    )\n    voc_dataset_val = dataset_VOC.VOCSegmentation(\n        root=dataset_folder,\n        image_set=\"val\",\n        transformations=transformations,\n        augmentations=False,\n    )\n    train_dl = DataLoader(\n        dataset=voc_dataset_train,\n        batch_size=batch_size,\n        shuffle=True,\n        num_workers=0,\n        pin_memory=True,\n    )\n    valid_dl = DataLoader(\n        dataset=voc_dataset_val,\n        batch_size=batch_size,\n        shuffle=False,\n        num_workers=0,\n        pin_memory=True,\n    )\n\n  
  # Load SmaAt-UNet\n    model = SmaAt_UNet(n_channels=3, n_classes=21)\n    # Move model to device\n    model.to(dev)\n    # Define Optimizer and loss\n    opt = optim.Adam(model.parameters(), lr=learning_rate)\n    loss_func = nn.CrossEntropyLoss().to(dev)\n\n    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode=\"max\", factor=0.1, patience=4)\n    # Train network\n    fit(\n        epochs=epochs,\n        model=model,\n        loss_func=loss_func,\n        opt=opt,\n        train_dl=train_dl,\n        valid_dl=valid_dl,\n        dev=dev,\n        save_every=save_every,\n        tensorboard=True,\n        earlystopping=earlystopping,\n        lr_scheduler=lr_scheduler,\n    )\n"
  },
  {
    "path": "train_precip_lightning.py",
    "content": "from root import ROOT_DIR\n\nimport lightning.pytorch as pl\nfrom lightning.pytorch.callbacks import (\n    ModelCheckpoint,\n    LearningRateMonitor,\n    EarlyStopping,\n)\nfrom lightning.pytorch import loggers\nimport argparse\nfrom models import unet_precip_regression_lightning as unet_regr\nfrom lightning.pytorch.tuner import Tuner\n\n\ndef train_regression(hparams, find_batch_size_automatically: bool = False):\n    if hparams.model == \"UNetDSAttention\":\n        net = unet_regr.UNetDSAttention(hparams=hparams)\n    elif hparams.model == \"UNetAttention\":\n        net = unet_regr.UNetAttention(hparams=hparams)\n    elif hparams.model == \"UNet\":\n        net = unet_regr.UNet(hparams=hparams)\n    elif hparams.model == \"UNetDS\":\n        net = unet_regr.UNetDS(hparams=hparams)\n    else:\n        raise NotImplementedError(f\"Model '{hparams.model}' not implemented\")\n\n    default_save_path = ROOT_DIR / \"lightning\" / \"precip_regression\"\n\n    checkpoint_callback = ModelCheckpoint(\n        dirpath=default_save_path / net.__class__.__name__,\n        filename=net.__class__.__name__ + \"_rain_threshold_50_{epoch}-{val_loss:.6f}\",\n        save_top_k=1,\n        verbose=False,\n        monitor=\"val_loss\",\n        mode=\"min\",\n    )\n\n    last_checkpoint_callback = ModelCheckpoint(\n        dirpath=default_save_path / net.__class__.__name__,\n        filename=net.__class__.__name__ + \"_rain_threshold_50_{epoch}-{val_loss:.6f}_last\",\n        save_top_k=1,\n        verbose=False,\n    )\n    \n    lr_monitor = LearningRateMonitor()\n    tb_logger = loggers.TensorBoardLogger(save_dir=default_save_path, name=net.__class__.__name__)\n\n    earlystopping_callback = EarlyStopping(\n        monitor=\"val_loss\",\n        mode=\"min\",\n        patience=hparams.es_patience,\n    )\n    trainer = pl.Trainer(\n        accelerator=\"gpu\",\n        devices=1,\n        fast_dev_run=hparams.fast_dev_run,\n        max_epochs=hparams.epochs,\n 
       default_root_dir=default_save_path,\n        logger=tb_logger,\n        callbacks=[checkpoint_callback, last_checkpoint_callback, earlystopping_callback, lr_monitor],\n        val_check_interval=hparams.val_check_interval,\n    )\n\n    if find_batch_size_automatically:\n        tuner = Tuner(trainer)\n\n        # Auto-scale batch size using binary search\n        tuner.scale_batch_size(net, mode=\"binsearch\")\n\n    # This can be used to speed up training with newer GPUs:\n    # https://lightning.ai/docs/pytorch/stable/advanced/speed.html#low-precision-matrix-multiplication\n    # torch.set_float32_matmul_precision('medium')\n\n    trainer.fit(model=net, ckpt_path=hparams.resume_from_checkpoint)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n\n    parser = unet_regr.PrecipRegressionBase.add_model_specific_args(parser)\n\n    parser.add_argument(\n        \"--dataset_folder\",\n        default=ROOT_DIR / \"data\" / \"precipitation\" / \"RAD_NL25_RAC_5min_train_test_2016-2019.h5\",\n        type=str,\n    )\n    parser.add_argument(\"--batch_size\", type=int, default=16)\n    parser.add_argument(\"--learning_rate\", type=float, default=0.001)\n    parser.add_argument(\"--epochs\", type=int, default=200)\n    parser.add_argument(\"--fast_dev_run\", action=\"store_true\")\n    parser.add_argument(\"--resume_from_checkpoint\", type=str, default=None)\n    parser.add_argument(\"--val_check_interval\", type=float, default=None)\n\n    args = parser.parse_args()\n\n    # args.fast_dev_run = True\n    args.n_channels = 12\n    # args.gpus = 1\n    args.model = \"UNetDSAttention\"\n    args.lr_patience = 4\n    args.es_patience = 15\n    # args.val_check_interval = 0.25\n    args.kernels_per_layer = 2\n    args.use_oversampled_dataset = True\n    args.dataset_folder = (\n        ROOT_DIR / \"data\" / \"precipitation\" / \"train_test_2016-2019_input-length_12_img-ahead_6_rain-threshold_50.h5\"\n    )\n    # 
args.resume_from_checkpoint = f\"lightning/precip_regression/{args.model}/UNetDSAttention.ckpt\"\n\n    # train_regression(args, find_batch_size_automatically=False)\n\n    # All the models below will be trained\n    for m in [\"UNet\", \"UNetDS\", \"UNetAttention\", \"UNetDSAttention\"]:\n        args.model = m\n        print(f\"Start training model: {m}\")\n        train_regression(args, find_batch_size_automatically=False)"
  },
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/data_loader_precip.py",
    "content": "import torch\nimport numpy as np\nfrom torchvision import transforms\nfrom torch.utils.data.sampler import SubsetRandomSampler\nfrom utils import dataset_precip\nfrom root import ROOT_DIR\n\n\n# Taken from: https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb\ndef get_train_valid_loader(\n    data_dir,\n    batch_size,\n    random_seed,\n    num_input_images,\n    num_output_images,\n    augment,\n    classification,\n    valid_size=0.1,\n    shuffle=True,\n    num_workers=4,\n    pin_memory=False,\n):\n    \"\"\"\n    Utility function for loading and returning train and valid\n    multi-process iterators over the precipitation dataset.\n    If using CUDA, pin_memory should be set to True.\n    Params\n    ------\n    - data_dir: path to the dataset file.\n    - batch_size: how many samples per batch to load.\n    - random_seed: fix seed for reproducibility.\n    - num_input_images: number of input images per sample.\n    - num_output_images: number of output images per sample.\n    - augment: whether to apply the data augmentation scheme\n      mentioned in the paper. Only applied on the train split.\n    - classification: whether to load the classification variant\n      of the dataset instead of the regression variant.\n    - valid_size: percentage split of the training set used for\n      the validation set. Should be a float in the range [0, 1].\n    - shuffle: whether to shuffle the train/validation indices.\n    - num_workers: number of subprocesses to use when loading the dataset.\n    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to\n      True if using GPU.\n    Returns\n    -------\n    - train_loader: training set iterator.\n    - valid_loader: validation set iterator.\n    \"\"\"\n    error_msg = \"[!] 
valid_size should be in the range [0, 1].\"\n    assert (valid_size >= 0) and (valid_size <= 1), error_msg\n\n    # Since I am not dealing with RGB images I do not need this\n    # normalize = transforms.Normalize(\n    #     mean=[0.4914, 0.4822, 0.4465],\n    #     std=[0.2023, 0.1994, 0.2010],\n    # )\n\n    # define transforms\n    valid_transform = None\n    # valid_transform = transforms.Compose([\n    #         transforms.ToTensor(),\n    #         normalize,\n    # ])\n    if augment:\n        # TODO flipping, rotating, sequence flipping (torch.flip seems very expensive)\n        train_transform = transforms.Compose(\n            [\n                transforms.RandomHorizontalFlip(),\n                # transforms.ToTensor(),\n                # normalize,\n            ]\n        )\n    else:\n        train_transform = None\n        # train_transform = transforms.Compose([\n        #     transforms.ToTensor(),\n        #     normalize,\n        # ])\n    if classification:\n        # load the dataset\n        train_dataset = dataset_precip.precipitation_maps_classification_h5(\n            in_file=data_dir,\n            num_input_images=num_input_images,\n            img_to_predict=num_output_images,\n            train=True,\n            transform=train_transform,\n        )\n\n        valid_dataset = dataset_precip.precipitation_maps_classification_h5(\n            in_file=data_dir,\n            num_input_images=num_input_images,\n            img_to_predict=num_output_images,\n            train=True,\n            transform=valid_transform,\n        )\n    else:\n        # load the dataset\n        train_dataset = dataset_precip.precipitation_maps_h5(\n            in_file=data_dir,\n            num_input_images=num_input_images,\n            num_output_images=num_output_images,\n            train=True,\n            transform=train_transform,\n        )\n\n        valid_dataset = dataset_precip.precipitation_maps_h5(\n            in_file=data_dir,\n            
num_input_images=num_input_images,\n            num_output_images=num_output_images,\n            train=True,\n            transform=valid_transform,\n        )\n\n    num_train = len(train_dataset)\n    indices = list(range(num_train))\n    split = int(np.floor(valid_size * num_train))\n\n    if shuffle:\n        np.random.seed(random_seed)\n        np.random.shuffle(indices)\n\n    train_idx, valid_idx = indices[split:], indices[:split]\n    train_sampler = SubsetRandomSampler(train_idx)\n    valid_sampler = SubsetRandomSampler(valid_idx)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=batch_size,\n        sampler=train_sampler,\n        num_workers=num_workers,\n        pin_memory=pin_memory,\n    )\n    valid_loader = torch.utils.data.DataLoader(\n        valid_dataset,\n        batch_size=batch_size,\n        sampler=valid_sampler,\n        num_workers=num_workers,\n        pin_memory=pin_memory,\n    )\n\n    return train_loader, valid_loader\n\n\ndef get_test_loader(\n    data_dir,\n    batch_size,\n    num_input_images,\n    num_output_images,\n    classification,\n    shuffle=False,\n    num_workers=4,\n    pin_memory=False,\n):\n    \"\"\"\n    Utility function for loading and returning a multi-process\n    test iterator over the precipitation dataset.\n    If using CUDA, pin_memory should be set to True.\n    Params\n    ------\n    - data_dir: path to the dataset file.\n    - batch_size: how many samples per batch to load.\n    - num_input_images: number of input images per sample.\n    - num_output_images: number of output images per sample.\n    - classification: whether to load the classification variant\n      of the dataset instead of the regression variant.\n    - shuffle: whether to shuffle the dataset after every epoch.\n    - num_workers: number of subprocesses to use when loading the dataset.\n    - pin_memory: whether to copy tensors into CUDA pinned memory. 
Set it to\n      True if using GPU.\n    Returns\n    -------\n    - data_loader: test set iterator.\n    \"\"\"\n    # Since I am not dealing with RGB images I do not need this\n    # normalize = transforms.Normalize(\n    #     mean=[0.485, 0.456, 0.406],\n    #     std=[0.229, 0.224, 0.225],\n    # )\n\n    # define transform\n    transform = None\n    # transform = transforms.Compose([\n    #     transforms.ToTensor(),\n    #     # normalize,\n    # ])\n    if classification:\n        dataset = dataset_precip.precipitation_maps_classification_h5(\n            in_file=data_dir,\n            num_input_images=num_input_images,\n            img_to_predict=num_output_images,\n            train=False,\n            transform=transform,\n        )\n    else:\n        dataset = dataset_precip.precipitation_maps_h5(\n            in_file=data_dir,\n            num_input_images=num_input_images,\n            num_output_images=num_output_images,\n            train=False,\n            transform=transform,\n        )\n\n    data_loader = torch.utils.data.DataLoader(\n        dataset,\n        batch_size=batch_size,\n        shuffle=shuffle,\n        num_workers=num_workers,\n        pin_memory=pin_memory,\n    )\n\n    return data_loader\n\n\nif __name__ == \"__main__\":\n    folder = ROOT_DIR / \"data\" / \"precipitation\"\n    data = \"RAD_NL25_RAC_5min_train_test_2016-2019.h5\"\n    train_dl, valid_dl = get_train_valid_loader(\n        folder / data,\n        batch_size=8,\n        random_seed=1337,\n        num_input_images=12,\n        num_output_images=6,\n        classification=True,\n        augment=False,\n        valid_size=0.1,\n        shuffle=True,\n        num_workers=4,\n        pin_memory=False,\n    )\n    for xb, yb in train_dl:\n        print(\"xb.shape: \", xb.shape)\n        print(\"yb.shape: \", yb.shape)\n        break\n"
  },
  {
    "path": "utils/dataset_VOC.py",
    "content": "import random\nimport torch\nfrom torch.utils.data import Dataset\nfrom PIL import Image\nfrom pathlib import Path\nfrom torchvision import transforms\nimport torchvision.transforms.functional as TF\nimport numpy as np\nimport matplotlib.pyplot as plt\n\n\ndef get_pascal_labels():\n    \"\"\"Load the mapping that associates pascal classes with label colors\n    Returns:\n        np.ndarray with dimensions (21, 3)\n    \"\"\"\n    return np.asarray(\n        [\n            [0, 0, 0],\n            [128, 0, 0],\n            [0, 128, 0],\n            [128, 128, 0],\n            [0, 0, 128],\n            [128, 0, 128],\n            [0, 128, 128],\n            [128, 128, 128],\n            [64, 0, 0],\n            [192, 0, 0],\n            [64, 128, 0],\n            [192, 128, 0],\n            [64, 0, 128],\n            [192, 0, 128],\n            [64, 128, 128],\n            [192, 128, 128],\n            [0, 64, 0],\n            [128, 64, 0],\n            [0, 192, 0],\n            [128, 192, 0],\n            [0, 64, 128],\n        ]\n    )\n\n\ndef decode_segmap(label_mask, plot=False):\n    \"\"\"Decode segmentation class labels into a color image\n    Args:\n        label_mask (np.ndarray): an (M,N) array of integer values denoting\n          the class label at each spatial location.\n        plot (bool, optional): whether to show the resulting color image\n          in a figure.\n    Returns:\n        (np.ndarray, optional): the resulting decoded color image.\n    \"\"\"\n    label_colours = get_pascal_labels()\n    r = label_mask.copy()\n    g = label_mask.copy()\n    b = label_mask.copy()\n    for ll in range(21):\n        r[label_mask == ll] = label_colours[ll, 0]\n        g[label_mask == ll] = label_colours[ll, 1]\n        b[label_mask == ll] = label_colours[ll, 2]\n    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))\n    rgb[:, :, 0] = r / 255.0\n    rgb[:, :, 1] = g / 255.0\n    rgb[:, :, 2] = b / 255.0\n    if plot:\n        
plt.imshow(rgb)\n        plt.show()\n    else:\n        return rgb\n\n\nclass VOCSegmentation(Dataset):\n    CLASS_NAMES = [\n        \"background\",\n        \"aeroplane\",\n        \"bicycle\",\n        \"bird\",\n        \"boat\",\n        \"bottle\",\n        \"bus\",\n        \"car\",\n        \"cat\",\n        \"chair\",\n        \"cow\",\n        \"diningtable\",\n        \"dog\",\n        \"horse\",\n        \"motorbike\",\n        \"person\",\n        \"potted-plant\",\n        \"sheep\",\n        \"sofa\",\n        \"train\",\n        \"tv/monitor\",\n    ]\n\n    def __init__(self, root: Path, image_set=\"train\", transformations=None, augmentations=False):\n        super().__init__()\n        assert image_set in [\"train\", \"val\", \"trainval\"]\n\n        voc_root = root / \"VOC2012\"\n        image_dir = voc_root / \"JPEGImages\"\n        mask_dir = voc_root / \"SegmentationClass\"\n\n        splits_dir = voc_root / \"ImageSets\" / \"Segmentation\"\n\n        split_f = splits_dir / (image_set.rstrip(\"\\n\") + \".txt\")\n        with open(split_f) as f:\n            file_names = [x.strip() for x in f.readlines()]\n\n        self.images = [image_dir / (x + \".jpg\") for x in file_names]\n        self.masks = [mask_dir / (x + \".png\") for x in file_names]\n        assert len(self.images) == len(self.masks)\n\n        self.transformations = transformations\n        self.augmentations = augmentations\n\n    def __getitem__(self, index):\n        img = Image.open(self.images[index]).convert(\"RGB\")\n        target = Image.open(self.masks[index])\n\n        if self.transformations is not None:\n            img = self.transformations(img)\n            target = self.transformations(target)\n        # Apply augmentations to both input and target\n        if self.augmentations:\n            img, target = self.apply_augmentations(img, target)\n\n        # Convert the RGB image to a tensor\n        toTensorTransform = transforms.Compose(\n            [\n       
         transforms.ToTensor(),\n                transforms.Normalize(\n                    mean=[0.485, 0.456, 0.406],\n                    std=[0.229, 0.224, 0.225],\n                ),\n            ]\n        )\n        img = toTensorTransform(img)\n        # Convert target to long tensor\n        target = torch.from_numpy(np.array(target)).long()\n        target[target == 255] = 0\n\n        return img, target\n\n    def __len__(self):\n        return len(self.images)\n\n    def apply_augmentations(self, img, target):\n        # Horizontal flip\n        if random.random() > 0.5:\n            img = TF.hflip(img)\n            target = TF.hflip(target)\n        # Random Rotation (clockwise and counterclockwise)\n        if random.random() > 0.5:\n            degrees = 10\n            if random.random() > 0.5:\n                degrees *= -1\n            img = TF.rotate(img, degrees)\n            target = TF.rotate(target, degrees)\n        # Brighten or darken image (only applied to input image)\n        if random.random() > 0.5:\n            brightness = 1.2\n            if random.random() > 0.5:\n                brightness -= 0.4\n            img = TF.adjust_brightness(img, brightness)\n        return img, target\n"
  },
  {
    "path": "utils/dataset_precip.py",
    "content": "from torch.utils.data import Dataset\nimport h5py\nimport numpy as np\n\n\nclass precipitation_maps_h5(Dataset):\n    def __init__(self, in_file, num_input_images, num_output_images, train=True, transform=None):\n        super().__init__()\n\n        self.file_name = in_file\n        self.n_images, self.nx, self.ny = h5py.File(self.file_name, \"r\")[\"train\" if train else \"test\"][\"images\"].shape\n\n        self.num_input = num_input_images\n        self.num_output = num_output_images\n        self.sequence_length = num_input_images + num_output_images\n\n        self.train = train\n        # Dataset is all the images\n        self.size_dataset = self.n_images - (num_input_images + num_output_images)\n        # self.size_dataset = int(self.n_images/(num_input_images+num_output_images))\n        self.transform = transform\n        self.dataset = None\n\n    def __getitem__(self, index):\n        # min_feature_range = 0.0\n        # max_feature_range = 1.0\n        # with h5py.File(self.file_name, 'r') as dataFile:\n        #     dataset = dataFile[\"train\" if self.train else \"test\"]['images'][index:index+self.sequence_length]\n        # load the file here (load as singleton)\n        if self.dataset is None:\n            self.dataset = h5py.File(self.file_name, \"r\", rdcc_nbytes=1024**3)[\"train\" if self.train else \"test\"][\n                \"images\"\n            ]\n        imgs = np.array(self.dataset[index : index + self.sequence_length], dtype=\"float32\")\n\n        # add transforms\n        if self.transform is not None:\n            imgs = self.transform(imgs)\n        input_img = imgs[: self.num_input]\n        target_img = imgs[-1]\n\n        return input_img, target_img\n\n    def __len__(self):\n        return self.size_dataset\n\n\nclass precipitation_maps_oversampled_h5(Dataset):\n    def __init__(self, in_file, num_input_images, num_output_images, train=True, transform=None):\n        super().__init__()\n\n        
self.file_name = in_file\n        self.samples, _, _, _ = h5py.File(self.file_name, \"r\")[\"train\" if train else \"test\"][\"images\"].shape\n\n        self.num_input = num_input_images\n        self.num_output = num_output_images\n\n        self.train = train\n        # self.size_dataset = int(self.n_images/(num_input_images+num_output_images))\n        self.transform = transform\n        self.dataset = None\n\n    def __getitem__(self, index):\n        # load the file here (load as singleton)\n        if self.dataset is None:\n            self.dataset = h5py.File(self.file_name, \"r\", rdcc_nbytes=1024**3)[\"train\" if self.train else \"test\"][\n                \"images\"\n            ]\n        imgs = np.array(self.dataset[index], dtype=\"float32\")\n\n        # add transforms\n        if self.transform is not None:\n            imgs = self.transform(imgs)\n        input_img = imgs[: self.num_input]\n        target_img = imgs[-1]\n\n        return input_img, target_img\n\n    def __len__(self):\n        return self.samples\n\n\nclass precipitation_maps_classification_h5(Dataset):\n    def __init__(self, in_file, num_input_images, img_to_predict, train=True, transform=None):\n        super().__init__()\n\n        self.file_name = in_file\n        self.n_images, self.nx, self.ny = h5py.File(self.file_name, \"r\")[\"train\" if train else \"test\"][\"images\"].shape\n\n        self.num_input = num_input_images\n        self.img_to_predict = img_to_predict\n        self.sequence_length = num_input_images + img_to_predict\n        self.bins = np.array([0.0, 0.5, 1, 2, 5, 10, 30])\n\n        self.train = train\n        # Dataset is all the images\n        self.size_dataset = self.n_images - (num_input_images + img_to_predict)\n        # self.size_dataset = int(self.n_images/(num_input_images+num_output_images))\n        self.transform = transform\n        self.dataset = None\n\n    def __getitem__(self, index):\n        # min_feature_range = 0.0\n        # 
max_feature_range = 1.0\n        # with h5py.File(self.file_name, 'r') as dataFile:\n        #     dataset = dataFile[\"train\" if self.train else \"test\"]['images'][index:index+self.sequence_length]\n        # load the file here (load as singleton)\n        if self.dataset is None:\n            self.dataset = h5py.File(self.file_name, \"r\", rdcc_nbytes=1024**3)[\"train\" if self.train else \"test\"][\n                \"images\"\n            ]\n        imgs = np.array(self.dataset[index : index + self.sequence_length], dtype=\"float32\")\n\n        # add transforms\n        if self.transform is not None:\n            imgs = self.transform(imgs)\n        input_img = imgs[: self.num_input]\n        # put the img in buckets\n        target_img = imgs[-1]\n        # target_img is normalized by dividing through the highest value of the training set. We reverse this.\n        # Then target_img is in mm/5min. The bins have the unit mm/hour. Therefore we multiply the img by 12\n        buckets = np.digitize(target_img * 47.83 * 12, self.bins, right=True)\n\n        return input_img, buckets\n\n    def __len__(self):\n        return self.size_dataset\n"
  },
  {
    "path": "utils/formatting.py",
    "content": "import torch\nimport numpy as np\n\n\ndef make_metrics_str(metrics_dict):\n    \"\"\"Format a metrics dict as 'name: value | ...', skipping NaN entries.\"\"\"\n    parts = []\n    for name, value in metrics_dict.items():\n        if isinstance(value, torch.Tensor):\n            value = value.item()\n        if np.isnan(value):\n            continue\n        parts.append(f\"{name}: {value:.4f}\")\n    return \" | \".join(parts)\n"
  },
  {
    "path": "utils/model_classes.py",
    "content": "from models import unet_precip_regression_lightning as unet_regr\nimport lightning.pytorch as pl\n\n\ndef get_model_class(model_file) -> tuple[type[pl.LightningModule], str]:\n    # Map checkpoint filenames to model classes and display names (used for plotting).\n    # The substring checks go from most to least specific, so e.g.\n    # \"UNetDSAttention\" must be tested before \"UNetDS\" and \"UNet\".\n    if \"UNetAttention\" in model_file:\n        model_name = \"UNet Attention\"\n        model = unet_regr.UNetAttention\n    elif \"UNetDSAttention4kpl\" in model_file:\n        model_name = \"UNetDS Attention with 4kpl\"\n        model = unet_regr.UNetDSAttention\n    elif \"UNetDSAttention1kpl\" in model_file:\n        model_name = \"UNetDS Attention with 1kpl\"\n        model = unet_regr.UNetDSAttention\n    elif \"UNetDSAttention4CBAMs\" in model_file:\n        model_name = \"UNetDS Attention 4CBAMs\"\n        model = unet_regr.UNetDSAttention4CBAMs\n    elif \"UNetDSAttention\" in model_file:\n        model_name = \"SmaAt-UNet\"\n        model = unet_regr.UNetDSAttention\n    elif \"UNetDS\" in model_file:\n        model_name = \"UNetDS\"\n        model = unet_regr.UNetDS\n    elif \"UNet\" in model_file:\n        model_name = \"UNet\"\n        model = unet_regr.UNet\n    elif \"PersistenceModel\" in model_file:\n        model_name = \"PersistenceModel\"\n        model = unet_regr.PersistenceModel\n    else:\n        raise NotImplementedError(f\"No model class found for '{model_file}'\")\n    return model, model_name\n"
  }
]