[
  {
    "path": ".gitignore",
    "content": "*.DS_STORE\n"
  },
  {
    "path": "README.md",
    "content": "# Filling the G_ap_s: Multivariate Time Series Imputation by Graph Neural Networks (ICLR 2022 - [open review](https://openreview.net/forum?id=kOu3-S3wJ7) - [pdf](https://openreview.net/pdf?id=kOu3-S3wJ7))\n\n[![ICLR](https://img.shields.io/badge/ICLR-2022-blue.svg?style=flat-square)](https://openreview.net/forum?id=kOu3-S3wJ7)\n[![PDF](https://img.shields.io/badge/%E2%87%A9-PDF-orange.svg?style=flat-square)](https://openreview.net/pdf?id=kOu3-S3wJ7)\n[![arXiv](https://img.shields.io/badge/arXiv-2108.00298-b31b1b.svg?style=flat-square)](https://arxiv.org/abs/2108.00298)\n\nThis repository contains the code for the reproducibility of the experiments presented in the paper \"Filling the G_ap_s: Multivariate Time Series Imputation by Graph Neural Networks\" (ICLR 2022). In this paper, we propose a graph neural network architecture for multivariate time series imputation and achieve state-of-the-art results on several benchmarks.\n\n**Authors**: [Andrea Cini](mailto:andrea.cini@usi.ch), [Ivan Marisca](mailto:ivan.marisca@usi.ch), Cesare Alippi\n\n\n**‼️ PyG implementation of GRIN is now available inside [Torch Spatiotemporal](https://github.com/TorchSpatiotemporal/tsl), a library built to accelerate research on neural spatiotemporal data processing methods, with a focus on Graph Neural Networks.**\n\n---\n\n<h2 align=center>GRIN in a nutshell</h2>\n\nThe [paper](https://arxiv.org/abs/2108.00298) introduces __GRIN__, a method and an architecture to exploit relational inductive biases to reconstruct missing values in multivariate time series coming from sensor networks. 
GRIN features a bidirectional recurrent GNN which learns __spatio-temporal node-level representations__ tailored to reconstruct observations at neighboring nodes.\n\n<p align=center>\n  <a href=\"https://github.com/marshka/sinfony\">\n    <img src=\"./grin.png\" alt=\"Logo\"/>\n  </a>\n</p>\n\n---\n\n## Directory structure\n\nThe directory is structured as follows:\n\n```\n.\n├── config\n│   ├── bimpgru\n│   ├── brits\n│   ├── grin\n│   ├── mpgru\n│   ├── rgain\n│   └── var\n├── datasets\n│   ├── air_quality\n│   ├── metr_la\n│   ├── pems_bay\n│   └── synthetic\n├── lib\n│   ├── __init__.py\n│   ├── data\n│   ├── datasets\n│   ├── fillers\n│   ├── nn\n│   └── utils\n├── requirements.txt\n└── scripts\n    ├── run_baselines.py\n    ├── run_imputation.py\n    └── run_synthetic.py\n\n```\nNote that, given their size, the datasets are not included in the repository. See the next section for download instructions.\n\n## Datasets\n\nAll the datasets used in the experiments, except CER-E, are open and can be downloaded from this [link](https://mega.nz/folder/qwwG3Qba#c6qFTeT7apmZKKyEunCzSg). The CER-E dataset can be obtained free of charge for research purposes by following the instructions at this [link](https://www.ucd.ie/issda/data/commissionforenergyregulationcer/). We recommend storing the downloaded datasets in a folder named `datasets` inside this directory.\n\n## Configuration files\n\nThe `config` directory stores all the configuration files used to run the experiments. They are divided into folders, one per model.\n\n## Library\n\nThe support code, including the models and the dataset readers, is packed in a Python library named `lib`. 
If you need to change the paths to the dataset locations, edit the `__init__.py` file of the library.\n\n## Scripts\n\nThe scripts used for the experiments in the paper are in the `scripts` folder.\n\n* `run_baselines.py` is used to compute the metrics for the `MEAN`, `KNN`, `MF` and `MICE` imputation methods. An example of usage is\n\n\t```\n\tpython ./scripts/run_baselines.py --datasets air36 air --imputers mean knn --k 10 --in-sample True --n-runs 5\n\t```\n\n* `run_imputation.py` is used to compute the metrics for the deep imputation methods. An example of usage is\n\n\t```\n\tpython ./scripts/run_imputation.py --config config/grin/air36.yaml --in-sample False\n\t```\n\n* `run_synthetic.py` is used for the experiments on the synthetic datasets. An example of usage is\n\n\t```\n\tpython ./scripts/run_synthetic.py --config config/grin/synthetic.yaml --static-adj False\n\t```\n\n## Requirements\n\nWe ran all the experiments with Python 3.8; see `requirements.txt` for the list of `pip` dependencies.\n\n## Bibtex reference\n\nIf you find this code useful, please consider citing our paper:\n\n```\n@inproceedings{cini2022filling,\n    title={Filling the G\\_ap\\_s: Multivariate Time Series Imputation by Graph Neural Networks},\n    author={Andrea Cini and Ivan Marisca and Cesare Alippi},\n    booktitle={International Conference on Learning Representations},\n    year={2022},\n    url={https://openreview.net/forum?id=kOu3-S3wJ7}\n}\n```\n"
  },
  {
    "path": "conda_env.yml",
    "content": "name: grin\nchannels:\n  - defaults\n  - pytorch\n  - conda-forge\ndependencies:\n  - pip\n  - pytables\n  - python=3.8\n  - pytorch=1.8\n  - torchvision\n  - torchaudio\n  - wheel\n  - pip:\n      - einops\n      - fancyimpute==0.6\n      - h5py\n      - openpyxl\n      - pandas\n      - pytorch-lightning==1.4\n      - pyyaml\n      - scikit-learn\n      - scipy\n      - tensorboard\n"
  },
  {
    "path": "config/bimpgru/air.yaml",
    "content": "dataset_name: 'air'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'bimpgru'\nd_hidden: 64\nd_emb: 8\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'"
  },
  {
    "path": "config/bimpgru/air36.yaml",
    "content": "dataset_name: 'air36'\nwindow: 36\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'bimpgru'\nd_hidden: 64\nd_emb: 8\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'"
  },
  {
    "path": "config/bimpgru/bay_block.yaml",
    "content": "dataset_name: 'bay_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'bimpgru'\nd_hidden: 64\nd_emb: 8\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'"
  },
  {
    "path": "config/bimpgru/bay_point.yaml",
    "content": "dataset_name: 'bay_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'bimpgru'\nd_hidden: 64\nd_emb: 8\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'"
  },
  {
    "path": "config/bimpgru/irish_block.yaml",
    "content": "dataset_name: 'irish_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'bimpgru'\nd_hidden: 64\nd_emb: 8\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'"
  },
  {
    "path": "config/bimpgru/irish_point.yaml",
    "content": "dataset_name: 'irish_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'bimpgru'\nd_hidden: 64\nd_emb: 8\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'"
  },
  {
    "path": "config/bimpgru/la_block.yaml",
    "content": "dataset_name: 'la_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'bimpgru'\nd_hidden: 64\nd_emb: 8\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'"
  },
  {
    "path": "config/bimpgru/la_point.yaml",
    "content": "dataset_name: 'la_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'bimpgru'\nd_hidden: 64\nd_emb: 8\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'"
  },
  {
    "path": "config/brits/air.yaml",
    "content": "dataset_name: 'air'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'brits'\nd_hidden: 128"
  },
  {
    "path": "config/brits/air36.yaml",
    "content": "dataset_name: 'air36'\nwindow: 36\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'brits'\nd_hidden: 64"
  },
  {
    "path": "config/brits/bay_block.yaml",
    "content": "dataset_name: 'bay_block'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'brits'\nd_hidden: 256"
  },
  {
    "path": "config/brits/bay_point.yaml",
    "content": "dataset_name: 'bay_point'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'brits'\nd_hidden: 256"
  },
  {
    "path": "config/brits/irish_block.yaml",
    "content": "dataset_name: 'irish_block'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'brits'\nd_hidden: 256"
  },
  {
    "path": "config/brits/irish_point.yaml",
    "content": "dataset_name: 'irish_point'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'brits'\nd_hidden: 256"
  },
  {
    "path": "config/brits/la_block.yaml",
    "content": "dataset_name: 'la_block'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'brits'\nd_hidden: 128"
  },
  {
    "path": "config/brits/la_point.yaml",
    "content": "dataset_name: 'la_point'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'brits'\nd_hidden: 128"
  },
  {
    "path": "config/brits/synthetic.yaml",
    "content": "window: 36\np_block: 0.025\np_point: 0.025\nmin_seq: 4\nmax_seq: 9\nuse_exogenous: False\n\nepochs: 200\nbatch_size: 32\n\nmodel_name: 'brits'\nd_hidden: 32"
  },
  {
    "path": "config/grin/air.yaml",
    "content": "dataset_name: 'air'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'grin'\npred_loss_weight: 1\n\nd_hidden: 64\nd_emb: 8\nd_ff: 64\nff_dropout: 0\nkernel_size: 2\ndecoder_order: 1\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'\n"
  },
  {
    "path": "config/grin/air36.yaml",
    "content": "dataset_name: 'air36'\nwindow: 36\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'grin'\npred_loss_weight: 1\n\nd_hidden: 64\nd_emb: 8\nd_ff: 64\nff_dropout: 0\nkernel_size: 2\ndecoder_order: 1\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'\n"
  },
  {
    "path": "config/grin/bay_block.yaml",
    "content": "dataset_name: 'bay_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'grin'\npred_loss_weight: 1\n\nd_hidden: 64\nd_emb: 8\nd_ff: 64\nff_dropout: 0\nkernel_size: 2\ndecoder_order: 1\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'\n"
  },
  {
    "path": "config/grin/bay_point.yaml",
    "content": "dataset_name: 'bay_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'grin'\npred_loss_weight: 1\n\nd_hidden: 64\nd_emb: 8\nd_ff: 64\nff_dropout: 0\nkernel_size: 2\ndecoder_order: 1\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'\n"
  },
  {
    "path": "config/grin/irish_block.yaml",
    "content": "dataset_name: 'irish_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'grin'\npred_loss_weight: 1\n\nd_hidden: 64\nd_emb: 8\nd_ff: 64\nff_dropout: 0\nkernel_size: 2\ndecoder_order: 1\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'\n"
  },
  {
    "path": "config/grin/irish_point.yaml",
    "content": "dataset_name: 'irish_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'grin'\npred_loss_weight: 1\n\nd_hidden: 64\nd_emb: 8\nd_ff: 64\nff_dropout: 0\nkernel_size: 2\ndecoder_order: 1\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'\n"
  },
  {
    "path": "config/grin/la_block.yaml",
    "content": "dataset_name: 'la_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'grin'\npred_loss_weight: 1\n\nd_hidden: 64\nd_emb: 8\nd_ff: 64\nff_dropout: 0\nkernel_size: 2\ndecoder_order: 1\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'\n"
  },
  {
    "path": "config/grin/la_point.yaml",
    "content": "dataset_name: 'la_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'grin'\npred_loss_weight: 1\n\nd_hidden: 64\nd_emb: 8\nd_ff: 64\nff_dropout: 0\nkernel_size: 2\ndecoder_order: 1\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'\n"
  },
  {
    "path": "config/grin/synthetic.yaml",
    "content": "window: 36\np_block: 0.025\np_point: 0.025\nmin_seq: 4\nmax_seq: 9\nuse_exogenous: False\n\nepochs: 200\nbatch_size: 32\n\nmodel_name: 'grin'\nd_hidden: 16\nd_emb: 0\nd_ff: 16\nff_dropout: 0\nkernel_size: 1\ndecoder_order: 1\nn_layers: 1\nlayer_norm: false\nmerge: 'mlp'\n"
  },
  {
    "path": "config/mpgru/air.yaml",
    "content": "dataset_name: 'air'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'mpgru'\npred_loss_weight: 1\nd_hidden: 64\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false"
  },
  {
    "path": "config/mpgru/air36.yaml",
    "content": "dataset_name: 'air36'\nwindow: 36\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'mpgru'\npred_loss_weight: 1\nd_hidden: 64\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false"
  },
  {
    "path": "config/mpgru/bay_block.yaml",
    "content": "dataset_name: 'bay_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'mpgru'\npred_loss_weight: 1\nd_hidden: 64\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false"
  },
  {
    "path": "config/mpgru/bay_point.yaml",
    "content": "dataset_name: 'bay_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'mpgru'\npred_loss_weight: 1\nd_hidden: 64\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false"
  },
  {
    "path": "config/mpgru/irish_block.yaml",
    "content": "dataset_name: 'irish_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'mpgru'\npred_loss_weight: 1\nd_hidden: 64\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false"
  },
  {
    "path": "config/mpgru/irish_point.yaml",
    "content": "dataset_name: 'irish_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'mpgru'\npred_loss_weight: 1\nd_hidden: 64\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false"
  },
  {
    "path": "config/mpgru/la_block.yaml",
    "content": "dataset_name: 'la_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'mpgru'\npred_loss_weight: 1\nd_hidden: 64\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false"
  },
  {
    "path": "config/mpgru/la_point.yaml",
    "content": "dataset_name: 'la_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'global']\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\naggregate_by: ['mean']\n\nmodel_name: 'mpgru'\npred_loss_weight: 1\nd_hidden: 64\nd_ff: 64\ndropout: 0\nkernel_size: 2\nn_layers: 1\nlayer_norm: false"
  },
  {
    "path": "config/rgain/air.yaml",
    "content": "dataset_name: 'air'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\nloss_fn: mse_loss\nconsistency_loss: False\nuse_lr_schedule: True\ngrad_clip_val: -1\naggregate_by: ['mean']\n\nmodel_name: 'gain'\nd_model: 128\nd_z: 4\ndropout: 0.2\ninject_noise: true\nalpha: 20\ng_train_freq: 3\nd_train_freq: 1"
  },
  {
    "path": "config/rgain/air36.yaml",
    "content": "dataset_name: 'air36'\nwindow: 36\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\nloss_fn: mse_loss\nconsistency_loss: False\nuse_lr_schedule: True\ngrad_clip_val: -1\naggregate_by: ['mean']\n\nmodel_name: 'gain'\nd_model: 64\nd_z: 4\ndropout: 0.1\ninject_noise: true\nalpha: 20\ng_train_freq: 3\nd_train_freq: 1"
  },
  {
    "path": "config/rgain/bay_block.yaml",
    "content": "dataset_name: 'bay_block'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\nloss_fn: mse_loss\nconsistency_loss: False\nuse_lr_schedule: True\ngrad_clip_val: -1\naggregate_by: ['mean']\n\nmodel_name: 'gain'\nd_model: 256\nd_z: 4\ndropout: 0.2\ninject_noise: true\nalpha: 20\ng_train_freq: 3\nd_train_freq: 1"
  },
  {
    "path": "config/rgain/bay_point.yaml",
    "content": "dataset_name: 'bay_point'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\nloss_fn: mse_loss\nconsistency_loss: False\nuse_lr_schedule: True\ngrad_clip_val: -1\naggregate_by: ['mean']\n\nmodel_name: 'gain'\nd_model: 256\nd_z: 4\ndropout: 0.2\ninject_noise: true\nalpha: 20\ng_train_freq: 3\nd_train_freq: 1"
  },
  {
    "path": "config/rgain/irish_block.yaml",
    "content": "dataset_name: 'irish_block'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\nloss_fn: mse_loss\nconsistency_loss: False\nuse_lr_schedule: True\ngrad_clip_val: -1\naggregate_by: ['mean']\n\nmodel_name: 'gain'\nd_model: 256\nd_z: 4\ndropout: 0.2\ninject_noise: true\nalpha: 20\ng_train_freq: 3\nd_train_freq: 1"
  },
  {
    "path": "config/rgain/irish_point.yaml",
    "content": "dataset_name: 'irish_point'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\nloss_fn: mse_loss\nconsistency_loss: False\nuse_lr_schedule: True\ngrad_clip_val: -1\naggregate_by: ['mean']\n\nmodel_name: 'gain'\nd_model: 256\nd_z: 4\ndropout: 0.2\ninject_noise: true\nalpha: 20\ng_train_freq: 3\nd_train_freq: 1"
  },
  {
    "path": "config/rgain/la_block.yaml",
    "content": "dataset_name: 'la_block'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\nloss_fn: mse_loss\nconsistency_loss: False\nuse_lr_schedule: True\ngrad_clip_val: -1\naggregate_by: ['mean']\n\nmodel_name: 'gain'\nd_model: 128\nd_z: 4\ndropout: 0.2\ninject_noise: true\nalpha: 20\ng_train_freq: 3\nd_train_freq: 1"
  },
  {
    "path": "config/rgain/la_point.yaml",
    "content": "dataset_name: 'la_point'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 160 batch of 32\nbatch_size: 32\nloss_fn: mse_loss\nconsistency_loss: False\nuse_lr_schedule: True\ngrad_clip_val: -1\naggregate_by: ['mean']\n\nmodel_name: 'gain'\nd_model: 128\nd_z: 4\ndropout: 0.2\ninject_noise: true\nalpha: 20\ng_train_freq: 3\nd_train_freq: 1"
  },
  {
    "path": "config/var/air.yaml",
    "content": "dataset_name: 'air'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 80 batches of 64\nlr: 0.0005\nbatch_size: 64\naggregate_by: ['mean']\n\nmodel_name: 'var'\norder: 5\npadding: 'mean'"
  },
  {
    "path": "config/var/air36.yaml",
    "content": "dataset_name: 'air36'\nwindow: 36\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 80 batches of 64\nlr: 0.0005\nbatch_size: 64\naggregate_by: ['mean']\n\nmodel_name: 'var'\norder: 5\npadding: 'mean'"
  },
  {
    "path": "config/var/bay_block.yaml",
    "content": "dataset_name: 'bay_block'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 80 batches of 64\nlr: 0.0005\nbatch_size: 64\naggregate_by: ['mean']\n\nmodel_name: 'var'\norder: 5\npadding: 'mean'"
  },
  {
    "path": "config/var/bay_point.yaml",
    "content": "dataset_name: 'bay_point'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 80 batches of 64\nlr: 0.0005\nbatch_size: 64\naggregate_by: ['mean']\n\nmodel_name: 'var'\norder: 5\npadding: 'mean'"
  },
  {
    "path": "config/var/irish_block.yaml",
    "content": "dataset_name: 'irish_block'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 80 batches of 64\nlr: 0.0005\nbatch_size: 64\naggregate_by: ['mean']\n\nmodel_name: 'var'\norder: 5\npadding: 'mean'"
  },
  {
    "path": "config/var/irish_point.yaml",
    "content": "dataset_name: 'irish_point'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 80 batches of 64\nlr: 0.0005\nbatch_size: 64\naggregate_by: ['mean']\n\nmodel_name: 'var'\norder: 5\npadding: 'mean'"
  },
  {
    "path": "config/var/la_block.yaml",
    "content": "dataset_name: 'la_block'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 80 batches of 64\nlr: 0.0005\nbatch_size: 64\naggregate_by: ['mean']\n\nmodel_name: 'var'\norder: 5\npadding: 'mean'"
  },
  {
    "path": "config/var/la_point.yaml",
    "content": "dataset_name: 'la_point'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsamples_per_epoch: 5120  # 80 batches of 64\nlr: 0.0005\nbatch_size: 64\naggregate_by: ['mean']\n\nmodel_name: 'var'\norder: 5\npadding: 'mean'"
  },
  {
    "path": "lib/__init__.py",
    "content": "import os\n\nbase_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))\n\nconfig = {\n    'logs': 'logs/'\n}\ndatasets_path = {\n    'air': 'datasets/air_quality',\n    'la': 'datasets/metr_la',\n    'bay': 'datasets/pems_bay',\n    'synthetic': 'datasets/synthetic'\n}\nepsilon = 1e-8\n\nfor k, v in config.items():\n    config[k] = os.path.join(base_dir, v)\nfor k, v in datasets_path.items():\n    datasets_path[k] = os.path.join(base_dir, v)\n"
  },
  {
    "path": "lib/data/__init__.py",
    "content": "from .temporal_dataset import TemporalDataset\nfrom .spatiotemporal_dataset import SpatioTemporalDataset\n"
  },
  {
    "path": "lib/data/datamodule/__init__.py",
    "content": "from .spatiotemporal import SpatioTemporalDataModule\n"
  },
  {
    "path": "lib/data/datamodule/spatiotemporal.py",
    "content": "import pytorch_lightning as pl\nfrom torch.utils.data import DataLoader, Subset, RandomSampler\n\nfrom .. import TemporalDataset, SpatioTemporalDataset\nfrom ..preprocessing import StandardScaler, MinMaxScaler\nfrom ...utils import ensure_list\nfrom ...utils.parser_utils import str_to_bool\n\n\nclass SpatioTemporalDataModule(pl.LightningDataModule):\n    \"\"\"\n    Pytorch Lightning DataModule for TimeSeriesDatasets\n    \"\"\"\n\n    def __init__(self, dataset: TemporalDataset,\n                 scale=True,\n                 scaling_axis='samples',\n                 scaling_type='std',\n                 scale_exogenous=None,\n                 train_idxs=None,\n                 val_idxs=None,\n                 test_idxs=None,\n                 batch_size=32,\n                 workers=1,\n                 samples_per_epoch=None):\n        super(SpatioTemporalDataModule, self).__init__()\n        self.torch_dataset = dataset\n        # splitting\n        self.trainset = Subset(self.torch_dataset, train_idxs if train_idxs is not None else [])\n        self.valset = Subset(self.torch_dataset, val_idxs if val_idxs is not None else [])\n        self.testset = Subset(self.torch_dataset, test_idxs if test_idxs is not None else [])\n        # preprocessing\n        self.scale = scale\n        self.scaling_type = scaling_type\n        self.scaling_axis = scaling_axis\n        self.scale_exogenous = ensure_list(scale_exogenous) if scale_exogenous is not None else None\n        # data loaders\n        self.batch_size = batch_size\n        self.workers = workers\n        self.samples_per_epoch = samples_per_epoch\n\n    @property\n    def is_spatial(self):\n        return isinstance(self.torch_dataset, SpatioTemporalDataset)\n\n    @property\n    def n_nodes(self):\n        if not self.has_setup_fit:\n            raise ValueError('You should initialize the datamodule first.')\n        return self.torch_dataset.n_nodes if self.is_spatial else None\n\n    
@property\n    def d_in(self):\n        if not self.has_setup_fit:\n            raise ValueError('You should initialize the datamodule first.')\n        return self.torch_dataset.n_channels\n\n    @property\n    def d_out(self):\n        if not self.has_setup_fit:\n            raise ValueError('You should initialize the datamodule first.')\n        return self.torch_dataset.horizon\n\n    @property\n    def train_slice(self):\n        return self.torch_dataset.expand_indices(self.trainset.indices, merge=True)\n\n    @property\n    def val_slice(self):\n        return self.torch_dataset.expand_indices(self.valset.indices, merge=True)\n\n    @property\n    def test_slice(self):\n        return self.torch_dataset.expand_indices(self.testset.indices, merge=True)\n\n    def get_scaling_axes(self, dim='global'):\n        scaling_axis = tuple()\n        if dim == 'global':\n            scaling_axis = (0, 1, 2)\n        elif dim == 'channels':\n            scaling_axis = (0, 1)\n        elif dim == 'nodes':\n            scaling_axis = (0,)\n        # Remove last dimension for temporal datasets\n        if not self.is_spatial:\n            scaling_axis = scaling_axis[:-1]\n\n        if not len(scaling_axis):\n            raise ValueError(f'Scaling axis \"{dim}\" not valid.')\n\n        return scaling_axis\n\n    def get_scaler(self):\n        if self.scaling_type == 'std':\n            return StandardScaler\n        elif self.scaling_type == 'minmax':\n            return MinMaxScaler\n        else:\n            # Raise (not return) so an unknown scaling type fails here instead of at the later call site\n            raise NotImplementedError(f'Scaling type \"{self.scaling_type}\" not implemented.')\n\n    def setup(self, stage=None):\n\n        if self.scale:\n            scaling_axis = self.get_scaling_axes(self.scaling_axis)\n            train = self.torch_dataset.data.numpy()[self.train_slice]\n            train_mask = self.torch_dataset.mask.numpy()[self.train_slice] if 'mask' in self.torch_dataset else None\n            scaler = self.get_scaler()(scaling_axis).fit(train, mask=train_mask, keepdims=True).to_torch()\n            
self.torch_dataset.scaler = scaler\n\n            if self.scale_exogenous is not None:\n                for label in self.scale_exogenous:\n                    exo = getattr(self.torch_dataset, label)\n                    scaler = self.get_scaler()(scaling_axis)\n                    scaler.fit(exo[self.train_slice], keepdims=True).to_torch()\n                    setattr(self.torch_dataset, label, scaler.transform(exo))\n\n    def _data_loader(self, dataset, shuffle=False, batch_size=None, **kwargs):\n        batch_size = self.batch_size if batch_size is None else batch_size\n        return DataLoader(dataset,\n                          shuffle=shuffle,\n                          batch_size=batch_size,\n                          num_workers=self.workers,\n                          **kwargs)\n\n    def train_dataloader(self, shuffle=True, batch_size=None):\n        if self.samples_per_epoch is not None:\n            sampler = RandomSampler(self.trainset, replacement=True, num_samples=self.samples_per_epoch)\n            return self._data_loader(self.trainset, False, batch_size, sampler=sampler, drop_last=True)\n        return self._data_loader(self.trainset, shuffle, batch_size, drop_last=True)\n\n    def val_dataloader(self, shuffle=False, batch_size=None):\n        return self._data_loader(self.valset, shuffle, batch_size)\n\n    def test_dataloader(self, shuffle=False, batch_size=None):\n        return self._data_loader(self.testset, shuffle, batch_size)\n\n    @staticmethod\n    def add_argparse_args(parser, **kwargs):\n        parser.add_argument('--batch-size', type=int, default=64)\n        parser.add_argument('--scaling-axis', type=str, default=\"channels\")\n        parser.add_argument('--scaling-type', type=str, default=\"std\")\n        parser.add_argument('--scale', type=str_to_bool, nargs='?', const=True, default=True)\n        parser.add_argument('--workers', type=int, default=0)\n        parser.add_argument('--samples-per-epoch', type=int, 
default=None)\n        return parser\n"
  },
  {
    "path": "lib/data/imputation_dataset.py",
    "content": "import numpy as np\nimport torch\n\nfrom . import TemporalDataset, SpatioTemporalDataset\n\n\nclass ImputationDataset(TemporalDataset):\n\n    def __init__(self, data,\n                 index=None,\n                 mask=None,\n                 eval_mask=None,\n                 freq=None,\n                 trend=None,\n                 scaler=None,\n                 window=24,\n                 stride=1,\n                 exogenous=None):\n        if mask is None:\n            mask = np.ones_like(data)\n        if exogenous is None:\n            exogenous = dict()\n        exogenous['mask_window'] = mask\n        if eval_mask is not None:\n            exogenous['eval_mask_window'] = eval_mask\n        super(ImputationDataset, self).__init__(data,\n                                                index=index,\n                                                exogenous=exogenous,\n                                                trend=trend,\n                                                scaler=scaler,\n                                                freq=freq,\n                                                window=window,\n                                                horizon=window,\n                                                delay=-window,\n                                                stride=stride)\n\n    def get(self, item, preprocess=False):\n        res, transform = super(ImputationDataset, self).get(item, preprocess)\n        res['x'] = torch.where(res['mask'], res['x'], torch.zeros_like(res['x']))\n        return res, transform\n\n\nclass GraphImputationDataset(ImputationDataset, SpatioTemporalDataset):\n    pass\n"
  },
  {
    "path": "lib/data/preprocessing/__init__.py",
    "content": "from .scalers import *\n"
  },
  {
    "path": "lib/data/preprocessing/scalers.py",
    "content": "from abc import ABC, abstractmethod\nimport numpy as np\n\n\nclass AbstractScaler(ABC):\n\n    def __init__(self, **kwargs):\n        for k, v in kwargs.items():\n            setattr(self, k, v)\n\n    def __repr__(self):\n        params = \", \".join([f\"{k}={str(v)}\" for k, v in self.params().items()])\n        return \"{}({})\".format(self.__class__.__name__, params)\n\n    def __call__(self, *args, **kwargs):\n        return self.transform(*args, **kwargs)\n\n    def params(self):\n        return {k: v for k, v in self.__dict__.items() if not callable(v) and not k.startswith(\"__\")}\n\n    @abstractmethod\n    def fit(self, x):\n        pass\n\n    @abstractmethod\n    def transform(self, x):\n        pass\n\n    @abstractmethod\n    def inverse_transform(self, x):\n        pass\n\n    def fit_transform(self, x):\n        self.fit(x)\n        return self.transform(x)\n\n    def to_torch(self):\n        import torch\n        for p in self.params():\n            param = getattr(self, p)\n            param = np.atleast_1d(param)\n            param = torch.tensor(param).float()\n            setattr(self, p, param)\n        return self\n\n\nclass Scaler(AbstractScaler):\n    def __init__(self, offset=0., scale=1.):\n        self.bias = offset\n        self.scale = scale\n        super(Scaler, self).__init__()\n\n    def params(self):\n        return dict(bias=self.bias, scale=self.scale)\n\n    def fit(self, x, mask=None, keepdims=True):\n        pass\n\n    def transform(self, x):\n        return (x - self.bias) / self.scale\n\n    def inverse_transform(self, x):\n        return x * self.scale + self.bias\n\n    def fit_transform(self, x, mask=None, keepdims=True):\n        self.fit(x, mask, keepdims)\n        return self.transform(x)\n\n\nclass StandardScaler(Scaler):\n    def __init__(self, axis=0):\n        self.axis = axis\n        super(StandardScaler, self).__init__()\n\n    def fit(self, x, mask=None, keepdims=True):\n        if mask is not 
None:\n            x = np.where(mask, x, np.nan)\n            self.bias = np.nanmean(x, axis=self.axis, keepdims=keepdims)\n            self.scale = np.nanstd(x, axis=self.axis, keepdims=keepdims)\n        else:\n            self.bias = x.mean(axis=self.axis, keepdims=keepdims)\n            self.scale = x.std(axis=self.axis, keepdims=keepdims)\n        return self\n\n\nclass MinMaxScaler(Scaler):\n    def __init__(self, axis=0):\n        self.axis = axis\n        super(MinMaxScaler, self).__init__()\n\n    def fit(self, x, mask=None, keepdims=True):\n        if mask is not None:\n            x = np.where(mask, x, np.nan)\n            self.bias = np.nanmin(x, axis=self.axis, keepdims=keepdims)\n            self.scale = (np.nanmax(x, axis=self.axis, keepdims=keepdims) - self.bias)\n        else:\n            self.bias = x.min(axis=self.axis, keepdims=keepdims)\n            self.scale = (x.max(axis=self.axis, keepdims=keepdims) - self.bias)\n        return self\n"
  },
  {
    "path": "lib/data/spatiotemporal_dataset.py",
    "content": "import numpy as np\nimport pandas as pd\nfrom einops import rearrange\n\nfrom .temporal_dataset import TemporalDataset\n\n\nclass SpatioTemporalDataset(TemporalDataset):\n    def __init__(self, data,\n                 index=None,\n                 trend=None,\n                 scaler=None,\n                 freq=None,\n                 window=24,\n                 horizon=24,\n                 delay=0,\n                 stride=1,\n                 **exogenous):\n        \"\"\"\n        Pytorch dataset for data that can be represented as a single TimeSeries\n\n        :param data:\n            raw target time series (ts) (can be multivariate), shape: [steps, (features), nodes]\n        :param exog:\n            global exogenous variables, shape: [steps, nodes]\n        :param trend:\n            trend time series to be removed from the ts, shape: [steps, (features), (nodes)]\n        :param bias:\n            bias to be removed from the ts (after de-trending), shape [steps, (features), (nodes)]\n        :param scale:\n            scaling factor to scale the ts (after de-trending), shape [steps, (features), (nodes)]\n        :param mask:\n            mask for valid data, 1 -> valid time step, 0 -> invalid. 
same shape as ts.\n        :param target_exog:\n            exogenous variables of the target, shape: [steps, nodes]\n        :param window:\n            length of windows returned by __getitem__\n        :param horizon:\n            length of prediction horizon returned by __getitem__\n        :param delay:\n            delay between input and prediction\n        \"\"\"\n        super(SpatioTemporalDataset, self).__init__(data,\n                                                    index=index,\n                                                    trend=trend,\n                                                    scaler=scaler,\n                                                    freq=freq,\n                                                    window=window,\n                                                    horizon=horizon,\n                                                    delay=delay,\n                                                    stride=stride,\n                                                    **exogenous)\n\n    def __repr__(self):\n        return \"{}(n_samples={}, n_nodes={})\".format(self.__class__.__name__, len(self), self.n_nodes)\n\n    @property\n    def n_nodes(self):\n        return self.data.shape[1]\n\n    @staticmethod\n    def check_dim(data):\n        if data.ndim == 2:  # [steps, nodes] -> [steps, nodes, features]\n            data = rearrange(data, 's (n f) -> s n f', f=1)\n        elif data.ndim == 1:\n            data = rearrange(data, '(s n f) -> s n f', n=1, f=1)\n        elif data.ndim == 3:\n            pass\n        else:\n            raise ValueError(f'Invalid data dimensions {data.shape}')\n        return data\n\n    def dataframe(self):\n        if self.n_channels == 1:\n            return pd.DataFrame(data=np.squeeze(self.data, -1), index=self.index)\n        raise NotImplementedError()\n"
  },
  {
    "path": "lib/data/temporal_dataset.py",
    "content": "import numpy as np\nimport pandas as pd\nimport torch\nfrom einops import rearrange\nfrom pandas import DatetimeIndex\nfrom torch.utils.data import Dataset\n\nfrom .preprocessing import AbstractScaler\n\n\nclass TemporalDataset(Dataset):\n    def __init__(self, data,\n                 index=None,\n                 freq=None,\n                 exogenous=None,\n                 trend=None,\n                 scaler=None,\n                 window=24,\n                 horizon=24,\n                 delay=0,\n                 stride=1):\n        \"\"\"Wrapper class for datasets whose entries depend on a sequence of temporal indices.\n\n        Parameters\n        ----------\n        data : np.ndarray\n            Data relative to the main signal.\n        index : DatetimeIndex or None\n            Temporal indices for the data.\n        exogenous : dict or None\n            Exogenous data and label paired with main signal (default is None).\n        trend : np.ndarray or None\n            Trend paired with main signal (default is None). 
Must be of the same length of 'data'.\n        scaler : AbstractScaler or None\n            Scaler that must be used for data (default is None).\n        freq : pd.DateTimeIndex.freq or str\n            Frequency of the indices (defaults is indices.freq).\n        window : int\n            Size of the sliding window in the past.\n        horizon : int\n            Size of the prediction horizon.\n        delay : int\n            Offset between end of window and start of horizon.\n\n        Raises\n        ----------\n        ValueError\n            If a frequency for the temporal indices is not provided neither in indices nor explicitly.\n            If preprocess is True and data_scaler is None.\n        \"\"\"\n        super(TemporalDataset, self).__init__()\n        # Initialize signatures\n        self.__exogenous_keys = dict()\n        self.__reserved_signature = {'data', 'trend', 'x', 'y'}\n        # Store data\n        self.data = data\n        if exogenous is not None:\n            for name, value in exogenous.items():\n                self.add_exogenous(value, name, for_window=True, for_horizon=True)\n        # Store time information\n        self.index = index\n        try:\n            freq = freq or index.freq or index.inferred_freq\n            self.freq = pd.tseries.frequencies.to_offset(freq)\n        except AttributeError:\n            self.freq = None\n        # Store offset information\n        self.window = window\n        self.delay = delay\n        self.horizon = horizon\n        self.stride = stride\n        # Identify the indices of the samples\n        self._indices = np.arange(self.data.shape[0] - self.sample_span + 1)[::self.stride]\n        # Store preprocessing options\n        self.trend = trend\n        self.scaler = scaler\n\n    def __getitem__(self, item):\n        return self.get(item, self.preprocess)\n\n    def __contains__(self, item):\n        return item in self.__exogenous_keys\n\n    def __len__(self):\n        return 
len(self._indices)\n\n    def __repr__(self):\n        return \"{}(n_samples={})\".format(self.__class__.__name__, len(self))\n\n    # Getter and setter for data\n\n    @property\n    def data(self):\n        return self.__data\n\n    @data.setter\n    def data(self, value):\n        assert value is not None\n        self.__data = self.check_input(value)\n\n    @property\n    def trend(self):\n        return self.__trend\n\n    @trend.setter\n    def trend(self, value):\n        self.__trend = self.check_input(value)\n\n    # Setter for exogenous data\n\n    def add_exogenous(self, obj, name, for_window=True, for_horizon=False):\n        assert isinstance(name, str)\n        if name.endswith('_window'):\n            name = name[:-7]\n            for_window, for_horizon = True, False\n        if name.endswith('_horizon'):\n            name = name[:-8]\n            for_window, for_horizon = False, True\n        if name in self.__reserved_signature:\n            raise ValueError(\"Channel '{0}' cannot be added in this way. 
Use obj.{0} instead.\".format(name))\n        if not (for_window or for_horizon):\n            raise ValueError(\"Either for_window or for_horizon must be True.\")\n        obj = self.check_input(obj)\n        setattr(self, name, obj)\n        self.__exogenous_keys[name] = dict(for_window=for_window, for_horizon=for_horizon)\n        return self\n\n    # Dataset properties\n\n    @property\n    def horizon_offset(self):\n        return self.window + self.delay\n\n    @property\n    def sample_span(self):\n        return max(self.horizon_offset + self.horizon, self.window)\n\n    @property\n    def preprocess(self):\n        return (self.trend is not None) or (self.scaler is not None)\n\n    @property\n    def n_steps(self):\n        return self.data.shape[0]\n\n    @property\n    def n_channels(self):\n        return self.data.shape[-1]\n\n    @property\n    def indices(self):\n        return self._indices\n\n    # Signature information\n\n    @property\n    def exo_window_keys(self):\n        return {k for k, v in self.__exogenous_keys.items() if v['for_window']}\n\n    @property\n    def exo_horizon_keys(self):\n        return {k for k, v in self.__exogenous_keys.items() if v['for_horizon']}\n\n    @property\n    def exo_common_keys(self):\n        return self.exo_window_keys.intersection(self.exo_horizon_keys)\n\n    @property\n    def signature(self):\n        attrs = []\n        if self.window > 0:\n            attrs.append('x')\n            for attr in self.exo_window_keys:\n                attrs.append(attr if attr not in self.exo_common_keys else (attr + '_window'))\n        for attr in self.exo_horizon_keys:\n            attrs.append(attr if attr not in self.exo_common_keys else (attr + '_horizon'))\n        attrs.append('y')\n        attrs = tuple(attrs)\n        preprocess = []\n        if self.trend is not None:\n            preprocess.append('trend')\n        if self.scaler is not None:\n            preprocess.extend(self.scaler.params())\n        
preprocess = tuple(preprocess)\n        return dict(data=attrs, preprocessing=preprocess)\n\n    # Item getters\n\n    def get(self, item, preprocess=False):\n        idx = self._indices[item]\n        res, transform = dict(), dict()\n        if self.window > 0:\n            res['x'] = self.data[idx:idx + self.window]\n            for attr in self.exo_window_keys:\n                key = attr if attr not in self.exo_common_keys else (attr + '_window')\n                res[key] = getattr(self, attr)[idx:idx + self.window]\n        for attr in self.exo_horizon_keys:\n            key = attr if attr not in self.exo_common_keys else (attr + '_horizon')\n            res[key] = getattr(self, attr)[idx + self.horizon_offset:idx + self.horizon_offset + self.horizon]\n        res['y'] = self.data[idx + self.horizon_offset:idx + self.horizon_offset + self.horizon]\n        if preprocess:\n            if self.trend is not None:\n                y_trend = self.trend[idx + self.horizon_offset:idx + self.horizon_offset + self.horizon]\n                res['y'] = res['y'] - y_trend\n                transform['trend'] = y_trend\n                if 'x' in res:\n                    res['x'] = res['x'] - self.trend[idx:idx + self.window]\n            if self.scaler is not None:\n                transform.update(self.scaler.params())\n                if 'x' in res:\n                    res['x'] = self.scaler.transform(res['x'])\n        return res, transform\n\n    def snapshot(self, indices=None, preprocess=True):\n        if not self.preprocess:\n            preprocess = False\n        data, prep = [{k: [] for k in sign} for sign in self.signature.values()]\n        indices = np.arange(len(self._indices)) if indices is None else indices\n        for idx in indices:\n            data_i, prep_i = self.get(idx, preprocess)\n            [v.append(data_i[k]) for k, v in data.items()]\n            if len(prep_i):\n                [v.append(prep_i[k]) for k, v in prep.items()]\n        data 
= {k: np.stack(ds) for k, ds in data.items() if len(ds)}\n        if len(prep):\n            prep = {k: np.stack(ds) if k == 'trend' else ds[0] for k, ds in prep.items() if len(ds)}\n        return data, prep\n\n    # Data utilities\n\n    def expand_indices(self, indices=None, unique=False, merge=False):\n        ds_indices = dict.fromkeys([time for time in ['window', 'horizon'] if getattr(self, time) > 0])\n        indices = np.arange(len(self._indices)) if indices is None else indices\n        if 'window' in ds_indices:\n            w_idxs = [np.arange(idx, idx + self.window) for idx in self._indices[indices]]\n            ds_indices['window'] = np.concatenate(w_idxs)\n        if 'horizon' in ds_indices:\n            h_idxs = [np.arange(idx + self.horizon_offset, idx + self.horizon_offset + self.horizon)\n                      for idx in self._indices[indices]]\n            ds_indices['horizon'] = np.concatenate(h_idxs)\n        if unique:\n            ds_indices = {k: np.unique(v) for k, v in ds_indices.items()}\n        if merge:\n            ds_indices = np.unique(np.concatenate(list(ds_indices.values())))\n        return ds_indices\n\n    def overlapping_indices(self, idxs1, idxs2, synch_mode='window', as_mask=False):\n        assert synch_mode in ['window', 'horizon']\n        ts1 = self.data_timestamps(idxs1, flatten=False)[synch_mode]\n        ts2 = self.data_timestamps(idxs2, flatten=False)[synch_mode]\n        common_ts = np.intersect1d(np.unique(ts1), np.unique(ts2))\n        is_overlapping = lambda sample: np.any(np.in1d(sample, common_ts))\n        m1 = np.apply_along_axis(is_overlapping, 1, ts1)\n        m2 = np.apply_along_axis(is_overlapping, 1, ts2)\n        if as_mask:\n            return m1, m2\n        return np.sort(idxs1[m1]), np.sort(idxs2[m2])\n\n    def data_timestamps(self, indices=None, flatten=True):\n        ds_indices = self.expand_indices(indices, unique=False)\n        ds_timestamps = {k: self.index[v] for k, v in 
ds_indices.items()}\n        if not flatten:\n            ds_timestamps = {k: np.array(v).reshape(-1, getattr(self, k)) for k, v in ds_timestamps.items()}\n        return ds_timestamps\n\n    def reduce_dataset(self, indices, inplace=False):\n        if not inplace:\n            from copy import deepcopy\n            dataset = deepcopy(self)\n        else:\n            dataset = self\n        old_index = dataset.index[dataset._indices[indices]]\n        ds_indices = dataset.expand_indices(indices, merge=True)\n        dataset.index = dataset.index[ds_indices]\n        dataset.data = dataset.data[ds_indices]\n        if dataset.mask is not None:\n            dataset.mask = dataset.mask[ds_indices]\n        if dataset.trend is not None:\n            dataset.trend = dataset.trend[ds_indices]\n        for attr in dataset.exo_window_keys.union(dataset.exo_horizon_keys):\n            if getattr(dataset, attr, None) is not None:\n                setattr(dataset, attr, getattr(dataset, attr)[ds_indices])\n        dataset._indices = np.flatnonzero(np.in1d(dataset.index, old_index))\n        return dataset\n\n    def check_input(self, data):\n        if data is None:\n            return data\n        data = self.check_dim(data)\n        data = data.clone().detach() if isinstance(data, torch.Tensor) else torch.tensor(data)\n        # cast data\n        if torch.is_floating_point(data):\n            return data.float()\n        elif data.dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64]:\n            return data.int()\n        return data\n\n    # Class-specific methods (override in children)\n\n    @staticmethod\n    def check_dim(data):\n        if data.ndim == 1:  # [steps] -> [steps, features]\n            data = rearrange(data, '(s f) -> s f', f=1)\n        elif data.ndim != 2:\n            raise ValueError(f'Invalid data dimensions {data.shape}')\n        return data\n\n    def dataframe(self):\n        return pd.DataFrame(data=self.data, 
index=self.index)\n\n    @staticmethod\n    def add_argparse_args(parser, **kwargs):\n        parser.add_argument('--window', type=int, default=24)\n        parser.add_argument('--horizon', type=int, default=24)\n        parser.add_argument('--delay', type=int, default=0)\n        parser.add_argument('--stride', type=int, default=1)\n        return parser\n"
  },
  {
    "path": "lib/datasets/__init__.py",
    "content": "from .air_quality import AirQuality\nfrom .metr_la import MissingValuesMetrLA\nfrom .pems_bay import MissingValuesPemsBay\nfrom .synthetic import ChargedParticles\n"
  },
  {
    "path": "lib/datasets/air_quality.py",
    "content": "import os\n\nimport numpy as np\nimport pandas as pd\n\nfrom lib import datasets_path\nfrom .pd_dataset import PandasDataset\nfrom ..utils.utils import disjoint_months, infer_mask, compute_mean, geographical_distance, thresholded_gaussian_kernel\n\n\nclass AirQuality(PandasDataset):\n    SEED = 3210\n\n    def __init__(self, impute_nans=False, small=False, freq='60T', masked_sensors=None):\n        self.random = np.random.default_rng(self.SEED)\n        self.test_months = [3, 6, 9, 12]\n        self.infer_eval_from = 'next'\n        self.eval_mask = None\n        df, dist, mask = self.load(impute_nans=impute_nans, small=small, masked_sensors=masked_sensors)\n        self.dist = dist\n        if masked_sensors is None:\n            self.masked_sensors = list()\n        else:\n            self.masked_sensors = list(masked_sensors)\n        super().__init__(dataframe=df, u=None, mask=mask, name='air', freq=freq, aggr='nearest')\n\n    def load_raw(self, small=False):\n        if small:\n            path = os.path.join(datasets_path['air'], 'small36.h5')\n            eval_mask = pd.DataFrame(pd.read_hdf(path, 'eval_mask'))\n        else:\n            path = os.path.join(datasets_path['air'], 'full437.h5')\n            eval_mask = None\n        df = pd.DataFrame(pd.read_hdf(path, 'pm25'))\n        stations = pd.DataFrame(pd.read_hdf(path, 'stations'))\n        return df, stations, eval_mask\n\n    def load(self, impute_nans=True, small=False, masked_sensors=None):\n        # load readings and stations metadata\n        df, stations, eval_mask = self.load_raw(small)\n        # compute the masks\n        mask = (~np.isnan(df.values)).astype('uint8')  # 1 if value is not nan else 0\n        if eval_mask is None:\n            eval_mask = infer_mask(df, infer_from=self.infer_eval_from)\n\n        eval_mask = eval_mask.values.astype('uint8')\n        if masked_sensors is not None:\n            eval_mask[:, masked_sensors] = np.where(mask[:, masked_sensors], 1, 
0)\n        self.eval_mask = eval_mask  # 1 if value is ground-truth for imputation else 0\n        # eventually replace nans with weekly mean by hour\n        if impute_nans:\n            df = df.fillna(compute_mean(df))\n        # compute distances from latitude and longitude degrees\n        st_coord = stations.loc[:, ['latitude', 'longitude']]\n        dist = geographical_distance(st_coord, to_rad=True).values\n        return df, dist, mask\n\n    def splitter(self, dataset, val_len=1., in_sample=False, window=0):\n        nontest_idxs, test_idxs = disjoint_months(dataset, months=self.test_months, synch_mode='horizon')\n        if in_sample:\n            train_idxs = np.arange(len(dataset))\n            val_months = [(m - 1) % 12 for m in self.test_months]\n            _, val_idxs = disjoint_months(dataset, months=val_months, synch_mode='horizon')\n        else:\n            # take equal number of samples before each month of testing\n            val_len = (int(val_len * len(nontest_idxs)) if val_len < 1 else val_len) // len(self.test_months)\n            # get indices of first day of each testing month\n            delta_idxs = np.diff(test_idxs)\n            end_month_idxs = test_idxs[1:][np.flatnonzero(delta_idxs > delta_idxs.min())]\n            if len(end_month_idxs) < len(self.test_months):\n                end_month_idxs = np.insert(end_month_idxs, 0, test_idxs[0])\n            # expand month indices\n            month_val_idxs = [np.arange(v_idx - val_len, v_idx) - window for v_idx in end_month_idxs]\n            val_idxs = np.concatenate(month_val_idxs) % len(dataset)\n            # remove overlapping indices from training set\n            ovl_idxs, _ = dataset.overlapping_indices(nontest_idxs, val_idxs, synch_mode='horizon', as_mask=True)\n            train_idxs = nontest_idxs[~ovl_idxs]\n        return [train_idxs, val_idxs, test_idxs]\n\n    def get_similarity(self, thr=0.1, include_self=False, force_symmetric=False, sparse=False, **kwargs):\n       
 theta = np.std(self.dist[:36, :36])  # use same theta for both air and air36\n        adj = thresholded_gaussian_kernel(self.dist, theta=theta, threshold=thr)\n        if not include_self:\n            adj[np.diag_indices_from(adj)] = 0.\n        if force_symmetric:\n            adj = np.maximum.reduce([adj, adj.T])\n        if sparse:\n            import scipy.sparse as sps\n            adj = sps.coo_matrix(adj)\n        return adj\n\n    @property\n    def mask(self):\n        return self._mask\n\n    @property\n    def training_mask(self):\n        return self._mask if self.eval_mask is None else (self._mask & (1 - self.eval_mask))\n\n    def test_interval_mask(self, dtype=bool, squeeze=True):\n        m = np.in1d(self.df.index.month, self.test_months).astype(dtype)\n        if squeeze:\n            return m\n        return m[:, None]\n"
  },
  {
    "path": "lib/datasets/metr_la.py",
    "content": "import os\n\nimport numpy as np\nimport pandas as pd\n\nfrom lib import datasets_path\nfrom .pd_dataset import PandasDataset\nfrom ..utils import sample_mask\n\n\nclass MetrLA(PandasDataset):\n    def __init__(self, impute_zeros=False, freq='5T'):\n\n        df, dist, mask = self.load(impute_zeros=impute_zeros)\n        self.dist = dist\n        super().__init__(dataframe=df, u=None, mask=mask, name='la', freq=freq, aggr='nearest')\n\n    def load(self, impute_zeros=True):\n        path = os.path.join(datasets_path['la'], 'metr_la.h5')\n        df = pd.read_hdf(path)\n        datetime_idx = sorted(df.index)\n        date_range = pd.date_range(datetime_idx[0], datetime_idx[-1], freq='5T')\n        df = df.reindex(index=date_range)\n        mask = ~np.isnan(df.values)\n        if impute_zeros:\n            mask = mask * (df.values != 0.).astype('uint8')\n            df = df.replace(to_replace=0., method='ffill')\n        else:\n            mask = None\n        dist = self.load_distance_matrix()\n        return df, dist, mask\n\n    def load_distance_matrix(self):\n        path = os.path.join(datasets_path['la'], 'metr_la_dist.npy')\n        try:\n            dist = np.load(path)\n        except:\n            distances = pd.read_csv(os.path.join(datasets_path['la'], 'distances_la.csv'))\n            with open(os.path.join(datasets_path['la'], 'sensor_ids_la.txt')) as f:\n                ids = f.read().strip().split(',')\n            num_sensors = len(ids)\n            dist = np.ones((num_sensors, num_sensors), dtype=np.float32) * np.inf\n            # Builds sensor id to index map.\n            sensor_id_to_ind = {int(sensor_id): i for i, sensor_id in enumerate(ids)}\n\n            # Fills cells in the matrix with distances.\n            for row in distances.values:\n                if row[0] not in sensor_id_to_ind or row[1] not in sensor_id_to_ind:\n                    continue\n                dist[sensor_id_to_ind[row[0]], sensor_id_to_ind[row[1]]] 
= row[2]\n            np.save(path, dist)\n        return dist\n\n    def get_similarity(self, thr=0.1, force_symmetric=False, sparse=False):\n        finite_dist = self.dist.reshape(-1)\n        finite_dist = finite_dist[~np.isinf(finite_dist)]\n        sigma = finite_dist.std()\n        adj = np.exp(-np.square(self.dist / sigma))\n        adj[adj < thr] = 0.\n        if force_symmetric:\n            adj = np.maximum.reduce([adj, adj.T])\n        if sparse:\n            import scipy.sparse as sps\n            adj = sps.coo_matrix(adj)\n        return adj\n\n    @property\n    def mask(self):\n        return self._mask\n\n\nclass MissingValuesMetrLA(MetrLA):\n    SEED = 9101112\n\n    def __init__(self, p_fault=0.0015, p_noise=0.05):\n        super(MissingValuesMetrLA, self).__init__(impute_zeros=True)\n        self.rng = np.random.default_rng(self.SEED)\n        self.p_fault = p_fault\n        self.p_noise = p_noise\n        eval_mask = sample_mask(self.numpy().shape,\n                                p=p_fault,\n                                p_noise=p_noise,\n                                min_seq=12,\n                                max_seq=12 * 4,\n                                rng=self.rng)\n        self.eval_mask = (eval_mask & self.mask).astype('uint8')\n\n    @property\n    def training_mask(self):\n        return self.mask if self.eval_mask is None else (self.mask & (1 - self.eval_mask))\n\n    def splitter(self, dataset, val_len=0, test_len=0, window=0):\n        idx = np.arange(len(dataset))\n        if test_len < 1:\n            test_len = int(test_len * len(idx))\n        if val_len < 1:\n            val_len = int(val_len * (len(idx) - test_len))\n        test_start = len(idx) - test_len\n        val_start = test_start - val_len\n        return [idx[:val_start - window], idx[val_start:test_start - window], idx[test_start:]]"
  },
  {
    "path": "lib/datasets/pd_dataset.py",
    "content": "import numpy as np\nimport pandas as pd\nimport torch\n\n\nclass PandasDataset:\n    def __init__(self, dataframe: pd.DataFrame, u: pd.DataFrame = None, name='pd-dataset', mask=None, freq=None,\n                 aggr='sum', **kwargs):\n        \"\"\"\n        Initialize a tsl dataset from a pandas dataframe.\n\n\n        :param dataframe: dataframe containing the data, shape: n_steps, n_nodes\n        :param u: dataframe with exog variables\n        :param name: optional name of the dataset\n        :param mask: mask for valid data (1:valid, 0:not valid)\n        :param freq: force a frequency (possibly by resampling)\n        :param aggr: aggregation method after resampling\n        \"\"\"\n        super().__init__()\n        self.name = name\n\n        # set dataset dataframe\n        self.df = dataframe\n\n        # set optional exog_variable dataframe\n        # make sure to consider only the overlapping part of the two dataframes\n        # assumption u.index \\in df.index\n        idx = sorted(self.df.index)\n        self.start = idx[0]\n        self.end = idx[-1]\n\n        if u is not None:\n            self.u = u[self.start:self.end]\n        else:\n            self.u = None\n\n        if mask is not None:\n            mask = np.asarray(mask).astype('uint8')\n        self._mask = mask\n\n        if freq is not None:\n            self.resample_(freq=freq, aggr=aggr)\n        else:\n            self.freq = self.df.index.inferred_freq\n            # make sure that all the dataframes are aligned\n            self.resample_(self.freq, aggr=aggr)\n\n        assert 'T' in self.freq\n        self.samples_per_day = int(60 / int(self.freq[:-1]) * 24)\n\n    def __repr__(self):\n        return \"{}(nodes={}, length={})\".format(self.__class__.__name__, self.n_nodes, self.length)\n\n    @property\n    def has_mask(self):\n        return self._mask is not None\n\n    @property\n    def has_u(self):\n        return self.u is not None\n\n    def 
resample_(self, freq, aggr):\n        resampler = self.df.resample(freq)\n        idx = self.df.index\n        if aggr == 'sum':\n            self.df = resampler.sum()\n        elif aggr == 'mean':\n            self.df = resampler.mean()\n        elif aggr == 'nearest':\n            self.df = resampler.nearest()\n        else:\n            raise ValueError(f'{aggr} is not a valid aggregation method.')\n\n        if self.has_mask:\n            resampler = pd.DataFrame(self._mask, index=idx).resample(freq)\n            self._mask = resampler.min().to_numpy()\n\n        if self.has_u:\n            resampler = self.u.resample(freq)\n            self.u = resampler.nearest()\n        self.freq = freq\n\n    def dataframe(self) -> pd.DataFrame:\n        return self.df.copy()\n\n    @property\n    def length(self):\n        return self.df.values.shape[0]\n\n    @property\n    def n_nodes(self):\n        return self.df.values.shape[1]\n\n    @property\n    def mask(self):\n        if self._mask is None:\n            return np.ones_like(self.df.values).astype('uint8')\n        return self._mask\n\n    def numpy(self, return_idx=False):\n        if return_idx:\n            return self.numpy(), self.df.index\n        return self.df.values\n\n    def pytorch(self):\n        data = self.numpy()\n        return torch.FloatTensor(data)\n\n    def __len__(self):\n        return self.length\n\n    @staticmethod\n    def build():\n        raise NotImplementedError\n\n    def load_raw(self):\n        raise NotImplementedError\n\n    def load(self):\n        raise NotImplementedError\n"
  },
  {
    "path": "lib/datasets/pems_bay.py",
    "content": "import os\n\nimport numpy as np\nimport pandas as pd\n\nfrom lib import datasets_path\nfrom .pd_dataset import PandasDataset\nfrom ..utils import sample_mask\n\n\nclass PemsBay(PandasDataset):\n    def __init__(self):\n        df, dist, mask = self.load()\n        self.dist = dist\n        super().__init__(dataframe=df, u=None, mask=mask, name='bay', freq='5T', aggr='nearest')\n\n    def load(self, impute_zeros=True):\n        path = os.path.join(datasets_path['bay'], 'pems_bay.h5')\n        df = pd.read_hdf(path)\n        datetime_idx = sorted(df.index)\n        date_range = pd.date_range(datetime_idx[0], datetime_idx[-1], freq='5T')\n        df = df.reindex(index=date_range)\n        mask = ~np.isnan(df.values)\n        df.fillna(method='ffill', axis=0, inplace=True)\n        dist = self.load_distance_matrix(list(df.columns))\n        return df.astype('float32'), dist, mask.astype('uint8')\n\n    def load_distance_matrix(self, ids):\n        path = os.path.join(datasets_path['bay'], 'pems_bay_dist.npy')\n        try:\n            dist = np.load(path)\n        except:\n            distances = pd.read_csv(os.path.join(datasets_path['bay'], 'distances_bay.csv'))\n            num_sensors = len(ids)\n            dist = np.ones((num_sensors, num_sensors), dtype=np.float32) * np.inf\n            # Builds sensor id to index map.\n            sensor_id_to_ind = {int(sensor_id): i for i, sensor_id in enumerate(ids)}\n\n            # Fills cells in the matrix with distances.\n            for row in distances.values:\n                if row[0] not in sensor_id_to_ind or row[1] not in sensor_id_to_ind:\n                    continue\n                dist[sensor_id_to_ind[row[0]], sensor_id_to_ind[row[1]]] = row[2]\n            np.save(path, dist)\n        return dist\n\n    def get_similarity(self, type='dcrnn', thr=0.1, force_symmetric=False, sparse=False):\n        \"\"\"\n        Return similarity matrix among nodes. 
Implemented to match DCRNN.\n\n        :param type: type of similarity matrix ('dcrnn' or 'stcn').\n        :param thr: threshold below which similarities are set to zero, to increase sparseness.\n        :param force_symmetric: force the result to be symmetric.\n        :param sparse: whether to return the result as a scipy sparse COO matrix.\n        :return: an NxN array representing similarity among nodes.\n        \"\"\"\n        if type == 'dcrnn':\n            finite_dist = self.dist.reshape(-1)\n            finite_dist = finite_dist[~np.isinf(finite_dist)]\n            sigma = finite_dist.std()\n            adj = np.exp(-np.square(self.dist / sigma))\n        elif type == 'stcn':\n            sigma = 10\n            adj = np.exp(-np.square(self.dist) / sigma)\n        else:\n            raise NotImplementedError\n        adj[adj < thr] = 0.\n        if force_symmetric:\n            adj = np.maximum.reduce([adj, adj.T])\n        if sparse:\n            import scipy.sparse as sps\n            adj = sps.coo_matrix(adj)\n        return adj\n\n    @property\n    def mask(self):\n        if self._mask is None:\n            return self.df.values != 0.\n        return self._mask\n\n\nclass MissingValuesPemsBay(PemsBay):\n    SEED = 56789\n\n    def __init__(self, p_fault=0.0015, p_noise=0.05):\n        super(MissingValuesPemsBay, self).__init__()\n        self.rng = np.random.default_rng(self.SEED)\n        self.p_fault = p_fault\n        self.p_noise = p_noise\n        eval_mask = sample_mask(self.numpy().shape,\n                                p=p_fault,\n                                p_noise=p_noise,\n                                min_seq=12,\n                                max_seq=12 * 4,\n                                rng=self.rng)\n        self.eval_mask = (eval_mask & self.mask).astype('uint8')\n\n    @property\n    def training_mask(self):\n        return self.mask if self.eval_mask is None else (self.mask & (1 - self.eval_mask))\n\n    def splitter(self, dataset, val_len=0, test_len=0, window=0):\n       
 idx = np.arange(len(dataset))\n        if test_len < 1:\n            test_len = int(test_len * len(idx))\n        if val_len < 1:\n            val_len = int(val_len * (len(idx) - test_len))\n        test_start = len(idx) - test_len\n        val_start = test_start - val_len\n        return [idx[:val_start - window], idx[val_start:test_start - window], idx[test_start:]]\n"
  },
  {
    "path": "lib/datasets/synthetic.py",
    "content": "import os.path\n\nimport numpy as np\nimport torch\nfrom einops import rearrange\nfrom torch.utils.data import Dataset, DataLoader, Subset\n\nfrom lib import datasets_path\n\n\ndef generate_mask(shape, p_block=0.01, p_point=0.01, max_seq=1, min_seq=1, rng=None):\n    \"\"\"Generate mask in which 1 denotes valid values, 0 missing ones. Assuming shape=(steps, ...).\"\"\"\n    if rng is None:\n        rand = np.random.random\n        randint = np.random.randint\n    else:\n        rand = rng.random\n        randint = rng.integers\n    # init mask\n    mask = np.ones(shape, dtype='uint8')\n    # block missing\n    if p_block > 0:\n        assert max_seq >= min_seq\n        for col in range(shape[1]):\n            i = 0\n            while i < shape[0]:\n                if rand() > p_block:\n                    i += 1\n                else:\n                    fault_len = int(randint(min_seq, max_seq + 1))\n                    mask[i:i + fault_len, col] = 0\n                    i += fault_len + 1  # at least one valid value between two blocks\n    # point missing\n    # let values before and after block missing always valid\n    diff = np.zeros(mask.shape, dtype='uint8')\n    diff[:-1] |= np.diff(mask, axis=0) < 0\n    diff[1:] |= np.diff(mask, axis=0) > 0\n    mask = np.where(mask - diff, rand(shape) > p_point, mask)\n    return mask\n\n\nclass SyntheticDataset(Dataset):\n    SEED: int\n\n    def __init__(self, filename,\n                 window=None,\n                 p_block=0.05,\n                 p_point=0.05,\n                 max_seq=6,\n                 min_seq=4,\n                 use_exogenous=True,\n                 mask_exogenous=True,\n                 graph_mode=True):\n        super(SyntheticDataset, self).__init__()\n        self.mask_exogenous = mask_exogenous\n        self.use_exogenous = use_exogenous\n        self.graph_mode = graph_mode\n        # fetch data\n        content = self.load(filename)\n        self.window = window if 
window is not None else content['loc'].shape[1]\n        self.loc = torch.tensor(content['loc'][:, :self.window]).float()\n        self.vel = torch.tensor(content['vel'][:, :self.window]).float()\n        self.adj = content['adj']\n        self.SEED = content['seed'].item()\n        # compute masks\n        self.rng = np.random.default_rng(self.SEED)\n        mask_shape = (len(self), self.window, self.n_nodes, 1)\n        mask = generate_mask(mask_shape,\n                             p_block=p_block,\n                             p_point=p_point,\n                             max_seq=max_seq,\n                             min_seq=min_seq,\n                             rng=self.rng).repeat(self.n_channels, -1)\n        eval_mask = 1 - generate_mask(mask_shape,\n                                      p_block=p_block,\n                                      p_point=p_point,\n                                      max_seq=max_seq,\n                                      min_seq=min_seq,\n                                      rng=self.rng).repeat(self.n_channels, -1)\n        self.mask = torch.tensor(mask).byte()\n        self.eval_mask = torch.tensor(eval_mask).byte() & self.mask\n        # store splitting indices\n        self.train_idxs = None\n        self.val_idxs = None\n        self.test_idxs = None\n\n    def __len__(self):\n        return self.loc.size(0)\n\n    def __getitem__(self, index):\n        eval_mask = self.eval_mask[index]\n        mask = self.training_mask[index]\n        x = mask * self.loc[index]\n        res = dict(x=x, mask=mask, eval_mask=eval_mask)\n        if self.use_exogenous:\n            u = self.vel[index]\n            if self.mask_exogenous:\n                u *= mask.all(-1, keepdims=True)\n            res.update(u=u)\n        res.update(y=self.loc[index])\n        if not self.graph_mode:\n            res = {k: rearrange(v, 's n f -> s (n f)') for k, v in res.items()}\n        return res\n\n    @property\n    def n_channels(self):\n        
return self.loc.size(-1)\n\n    @property\n    def n_nodes(self):\n        return self.loc.size(-2)\n\n    @property\n    def n_exogenous(self):\n        return self.vel.size(-1) if self.use_exogenous else 0\n\n    @property\n    def training_mask(self):\n        return self.mask if self.eval_mask is None else (self.mask & (1 - self.eval_mask))\n\n    @staticmethod\n    def load(filename):\n        return np.load(filename)\n\n    def get_similarity(self, sparse=False):\n        return self.adj\n\n    # Splitting options\n\n    def split(self, val_len=0, test_len=0):\n        idx = np.arange(len(self))\n        if test_len < 1:\n            test_len = int(test_len * len(idx))\n        if val_len < 1:\n            val_len = int(val_len * (len(idx) - test_len))\n        test_start = len(idx) - test_len\n        val_start = test_start - val_len\n        # split dataset\n        self.train_idxs = idx[:val_start]\n        self.val_idxs = idx[val_start:test_start]\n        self.test_idxs = idx[test_start:]\n\n    def train_dataloader(self, shuffle=True, batch_size=32):\n        return DataLoader(Subset(self, self.train_idxs), shuffle=shuffle, batch_size=batch_size, drop_last=True)\n\n    def val_dataloader(self, shuffle=False, batch_size=32):\n        return DataLoader(Subset(self, self.val_idxs), shuffle=shuffle, batch_size=batch_size)\n\n    def test_dataloader(self, shuffle=False, batch_size=32):\n        return DataLoader(Subset(self, self.test_idxs), shuffle=shuffle, batch_size=batch_size)\n\n\nclass ChargedParticles(SyntheticDataset):\n\n    def __init__(self, static_adj=False,\n                 window=None,\n                 p_block=0.05,\n                 p_point=0.05,\n                 max_seq=6,\n                 min_seq=4,\n                 use_exogenous=True,\n                 mask_exogenous=True,\n                 graph_mode=True):\n        if static_adj:\n            filename = os.path.join(datasets_path['synthetic'], 'charged_static.npz')\n        else:\n   
         filename = os.path.join(datasets_path['synthetic'], 'charged_varying.npz')\n        self.static_adj = static_adj\n        super(ChargedParticles, self).__init__(filename, window,\n                                               p_block=p_block,\n                                               p_point=p_point,\n                                               max_seq=max_seq,\n                                               min_seq=min_seq,\n                                               use_exogenous=use_exogenous,\n                                               mask_exogenous=mask_exogenous,\n                                               graph_mode=graph_mode)\n        charges = self.load(filename)['charges']\n        self.charges = torch.tensor(charges).float()\n\n    def __getitem__(self, item):\n        res = super(ChargedParticles, self).__getitem__(item)\n        # add charges as exogenous features\n        if self.use_exogenous:\n            charges = self.charges[item] if not self.static_adj else self.charges\n            stacked_charges = charges[None].expand(self.window, -1, -1)\n            if not self.graph_mode:\n                stacked_charges = rearrange(stacked_charges, 's n f -> s (n f)')\n            res.update(u=torch.cat([res['u'], stacked_charges], -1))\n        return res\n\n    def get_similarity(self, sparse=False):\n        return np.ones((self.n_nodes, self.n_nodes)) - np.eye(self.n_nodes)\n\n    @property\n    def n_exogenous(self):\n        if self.use_exogenous:\n            return super(ChargedParticles, self).n_exogenous + 1  # add charges to features\n        return 0\n"
  },
  {
    "path": "lib/fillers/__init__.py",
    "content": "from .filler import Filler\nfrom .britsfiller import BRITSFiller\nfrom .graphfiller import GraphFiller\nfrom .rgainfiller import RGAINFiller\nfrom .multi_imputation_filler import MultiImputationFiller\n"
  },
  {
    "path": "lib/fillers/britsfiller.py",
    "content": "import torch\n\nfrom . import Filler\nfrom ..nn import BRITS\n\n\nclass BRITSFiller(Filler):\n\n    def training_step(self, batch, batch_idx):\n        # Unpack batch\n        batch_data, batch_preprocessing = self._unpack_batch(batch)\n\n        # Extract mask and target\n        mask = batch_data['mask'].clone().detach()\n        batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()\n        eval_mask = batch_data.pop('eval_mask', None)\n        y = batch_data.pop('y')\n\n        # Compute predictions and compute loss\n        out, imputations, predictions = self.predict_batch(batch, preprocess=False, postprocess=False)\n\n        if self.scaled_target:\n            target = self._preprocess(y, batch_preprocessing)\n        else:\n            target = y\n            imputations = [self._postprocess(imp, batch_preprocessing) for imp in imputations]\n            predictions = [self._postprocess(prd, batch_preprocessing) for prd in predictions]\n\n        loss = sum([self.loss_fn(pred, target, mask) for pred in predictions])\n        loss += BRITS.consistency_loss(*imputations)\n\n        # Logging\n        metrics_mask = (mask | eval_mask) - batch_data['mask']  # all unseen data\n        out = self._postprocess(out, batch_preprocessing)\n        self.train_metrics.update(out.detach(), y, metrics_mask)\n        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)\n        self.log('train_loss', loss, on_step=False, on_epoch=True, logger=True, prog_bar=False)\n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        # Unpack batch\n        batch_data, batch_preprocessing = self._unpack_batch(batch)\n\n        # Extract mask and target\n        mask = batch_data.get('mask')\n        eval_mask = batch_data.pop('eval_mask', None)\n        y = batch_data.pop('y')\n\n        # Compute predictions and compute loss\n        out, imputations, predictions = 
self.predict_batch(batch, preprocess=False, postprocess=False)\n\n        if self.scaled_target:\n            target = self._preprocess(y, batch_preprocessing)\n        else:\n            target = y\n            predictions = [self._postprocess(prd, batch_preprocessing) for prd in predictions]\n\n        val_loss = sum([self.loss_fn(pred, target, mask) for pred in predictions])\n\n        # Logging\n        out = self._postprocess(out, batch_preprocessing)\n        self.val_metrics.update(out.detach(), y, eval_mask)\n        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)\n        self.log('val_loss', val_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)\n        return val_loss\n\n    def test_step(self, batch, batch_idx):\n        # Unpack batch\n        batch_data, batch_preprocessing = self._unpack_batch(batch)\n\n        # Extract mask and target\n        eval_mask = batch_data.pop('eval_mask', None)\n        y = batch_data.pop('y')\n\n        # Compute outputs and rescale\n        imputation, *_ = self.predict_batch(batch, preprocess=False, postprocess=True)\n        test_loss = self.loss_fn(imputation, y, eval_mask)\n\n        # Logging\n        self.test_metrics.update(imputation.detach(), y, eval_mask)\n        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)\n        self.log('test_loss', test_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)\n        return test_loss\n"
  },
  {
    "path": "lib/fillers/filler.py",
    "content": "import inspect\nfrom copy import deepcopy\n\nimport pytorch_lightning as pl\nimport torch\nfrom pytorch_lightning.core.decorators import auto_move_data\nfrom pytorch_lightning.metrics import MetricCollection\nfrom pytorch_lightning.utilities import move_data_to_device\n\nfrom .. import epsilon\nfrom ..nn.utils.metric_base import MaskedMetric\nfrom ..utils.utils import ensure_list\n\n\nclass Filler(pl.LightningModule):\n    def __init__(self,\n                 model_class,\n                 model_kwargs,\n                 optim_class,\n                 optim_kwargs,\n                 loss_fn,\n                 scaled_target=False,\n                 whiten_prob=0.05,\n                 metrics=None,\n                 scheduler_class=None,\n                 scheduler_kwargs=None):\n        \"\"\"\n        PL module to implement hole fillers.\n\n        :param model_class: Class of pytorch nn.Module implementing the imputer.\n        :param model_kwargs: Model's keyword arguments.\n        :param optim_class: Optimizer class.\n        :param optim_kwargs: Optimizer's keyword arguments.\n        :param loss_fn: Loss function used for training.\n        :param scaled_target: Whether to scale target before computing loss using batch processing information.\n        :param whiten_prob: Probability of removing a value and using it as ground truth for imputation.\n        :param metrics: Dictionary of type {'metric1_name':metric1_fn, 'metric2_name':metric2_fn ...}.\n        :param scheduler_class: Scheduler class.\n        :param scheduler_kwargs: Scheduler's keyword arguments.\n        \"\"\"\n        super(Filler, self).__init__()\n        self.save_hyperparameters(model_kwargs)\n        self.model_cls = model_class\n        self.model_kwargs = model_kwargs\n        self.optim_class = optim_class\n        self.optim_kwargs = optim_kwargs\n        self.scheduler_class = scheduler_class\n        if scheduler_kwargs is None:\n            self.scheduler_kwargs = 
dict()\n        else:\n            self.scheduler_kwargs = scheduler_kwargs\n\n        if loss_fn is not None:\n            self.loss_fn = self._check_metric(loss_fn, on_step=True)\n        else:\n            self.loss_fn = None\n\n        self.scaled_target = scaled_target\n\n        # during training whiten ground-truth values with this probability\n        assert 0. <= whiten_prob <= 1.\n        self.keep_prob = 1. - whiten_prob\n\n        if metrics is None:\n            metrics = dict()\n        self._set_metrics(metrics)\n        # instantiate model\n        self.model = self.model_cls(**self.model_kwargs)\n\n    def reset_model(self):\n        self.model = self.model_cls(**self.model_kwargs)\n\n    @property\n    def trainable_parameters(self):\n        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)\n\n    @auto_move_data\n    def forward(self, *args, **kwargs):\n        return self.model(*args, **kwargs)\n\n    @staticmethod\n    def _check_metric(metric, on_step=False):\n        if not isinstance(metric, MaskedMetric):\n            if 'reduction' in inspect.getfullargspec(metric).args:\n                metric_kwargs = {'reduction': 'none'}\n            else:\n                metric_kwargs = dict()\n            return MaskedMetric(metric, compute_on_step=on_step, metric_kwargs=metric_kwargs)\n        return deepcopy(metric)\n\n    def _set_metrics(self, metrics):\n        self.train_metrics = MetricCollection(\n            {f'train_{k}': self._check_metric(m, on_step=True) for k, m in metrics.items()})\n        self.val_metrics = MetricCollection({f'val_{k}': self._check_metric(m) for k, m in metrics.items()})\n        self.test_metrics = MetricCollection({f'test_{k}': self._check_metric(m) for k, m in metrics.items()})\n\n    def _preprocess(self, data, batch_preprocessing):\n        \"\"\"\n        Perform preprocessing of a given input.\n\n        :param data: pytorch tensor of shape [batch, steps, nodes, features] to 
preprocess\n        :param batch_preprocessing: dictionary containing preprocessing data\n        :return: preprocessed data\n        \"\"\"\n        if isinstance(data, (list, tuple)):\n            return [self._preprocess(d, batch_preprocessing) for d in data]\n        trend = batch_preprocessing.get('trend', 0.)\n        bias = batch_preprocessing.get('bias', 0.)\n        scale = batch_preprocessing.get('scale', 1.)\n        return (data - trend - bias) / (scale + epsilon)\n\n    def _postprocess(self, data, batch_preprocessing):\n        \"\"\"\n        Perform postprocessing (inverse transform) of a given input.\n\n        :param data: pytorch tensor of shape [batch, steps, nodes, features] to transform\n        :param batch_preprocessing: dictionary containing preprocessing data\n        :return: inverse transformed data\n        \"\"\"\n        if isinstance(data, (list, tuple)):\n            return [self._postprocess(d, batch_preprocessing) for d in data]\n        trend = batch_preprocessing.get('trend', 0.)\n        bias = batch_preprocessing.get('bias', 0.)\n        scale = batch_preprocessing.get('scale', 1.)\n        return data * (scale + epsilon) + bias + trend\n\n    def predict_batch(self, batch, preprocess=False, postprocess=True, return_target=False):\n        \"\"\"\n        This method takes as input a batch, given as a pair of dictionaries containing tensors, and outputs the predictions.\n        Prediction should have a shape [batch, nodes, horizon]\n\n        :param batch: list of two dictionaries following the structure [data:\n                                                                {'x':[...], 'y':[...], 'u':[...], ...},\n                                                              preprocessing:\n                                                                {'bias': ..., 'scale': ..., 'x_trend':[...], 'y_trend':[...]}]\n        :param preprocess: whether the data need to be preprocessed (note that inputs are by default preprocessed before creating 
the batch)\n        :param postprocess: whether to postprocess the predictions (if True we assume that the model has learned to predict the transformed signal)\n        :param return_target: whether to return the prediction target y_true and the prediction mask\n        :return: (y_true), y_hat, (mask)\n        \"\"\"\n        batch_data, batch_preprocessing = self._unpack_batch(batch)\n        if preprocess:\n            x = batch_data.pop('x')\n            x = self._preprocess(x, batch_preprocessing)\n            y_hat = self.forward(x, **batch_data)\n        else:\n            y_hat = self.forward(**batch_data)\n        # Rescale outputs\n        if postprocess:\n            y_hat = self._postprocess(y_hat, batch_preprocessing)\n        if return_target:\n            y = batch_data.get('y')\n            mask = batch_data.get('mask', None)\n            return y, y_hat, mask\n        return y_hat\n\n    def predict_loader(self, loader, preprocess=False, postprocess=True, return_mask=True):\n        \"\"\"\n        Makes predictions for an input dataloader. 
Returns both the predictions and the predictions targets.\n\n        :param loader: torch dataloader\n        :param preprocess: whether to preprocess the data\n        :param postprocess: whether to postprocess the data\n        :param return_mask: whether to return the valid mask (if it exists)\n        :return: y_true, y_hat\n        \"\"\"\n        targets, imputations, masks = [], [], []\n        for batch in loader:\n            batch = move_data_to_device(batch, self.device)\n            batch_data, batch_preprocessing = self._unpack_batch(batch)\n            # Extract mask and target\n            eval_mask = batch_data.pop('eval_mask', None)\n            y = batch_data.pop('y')\n\n            y_hat = self.predict_batch(batch, preprocess=preprocess, postprocess=postprocess)\n\n            if isinstance(y_hat, (list, tuple)):\n                y_hat = y_hat[0]\n\n            targets.append(y)\n            imputations.append(y_hat)\n            masks.append(eval_mask)\n\n        y = torch.cat(targets, 0)\n        y_hat = torch.cat(imputations, 0)\n        if return_mask:\n            mask = torch.cat(masks, 0) if masks[0] is not None else None\n            return y, y_hat, mask\n        return y, y_hat\n\n    def _unpack_batch(self, batch):\n        \"\"\"\n        Unpack a batch into data and preprocessing dictionaries.\n\n        :param batch: the batch\n        :return: batch_data, batch_preprocessing\n        \"\"\"\n        if isinstance(batch, (tuple, list)) and (len(batch) == 2):\n            batch_data, batch_preprocessing = batch\n        else:\n            batch_data = batch\n            batch_preprocessing = dict()\n        return batch_data, batch_preprocessing\n\n    def training_step(self, batch, batch_idx):\n        # Unpack batch\n        batch_data, batch_preprocessing = self._unpack_batch(batch)\n\n        # Extract mask and target\n        mask = batch_data['mask'].clone().detach()\n        batch_data['mask'] = 
torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()\n        eval_mask = batch_data.pop('eval_mask')\n        eval_mask = (mask | eval_mask) - batch_data['mask']\n\n        y = batch_data.pop('y')\n\n        # Compute predictions and compute loss\n        imputation = self.predict_batch(batch, preprocess=False, postprocess=False)\n\n        if self.scaled_target:\n            target = self._preprocess(y, batch_preprocessing)\n        else:\n            target = y\n            imputation = self._postprocess(imputation, batch_preprocessing)\n\n        loss = self.loss_fn(imputation, target, mask)\n\n        # Logging\n        if self.scaled_target:\n            imputation = self._postprocess(imputation, batch_preprocessing)\n        self.train_metrics.update(imputation.detach(), y, eval_mask)  # all unseen data\n        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)\n        self.log('train_loss', loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)\n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        # Unpack batch\n        batch_data, batch_preprocessing = self._unpack_batch(batch)\n\n        # Extract mask and target\n        eval_mask = batch_data.pop('eval_mask', None)\n        y = batch_data.pop('y')\n\n        # Compute predictions and compute loss\n        imputation = self.predict_batch(batch, preprocess=False, postprocess=False)\n\n        if self.scaled_target:\n            target = self._preprocess(y, batch_preprocessing)\n        else:\n            target = y\n            imputation = self._postprocess(imputation, batch_preprocessing)\n\n        val_loss = self.loss_fn(imputation, target, eval_mask)\n\n        # Logging\n        if self.scaled_target:\n            imputation = self._postprocess(imputation, batch_preprocessing)\n        self.val_metrics.update(imputation.detach(), y, eval_mask)\n        self.log_dict(self.val_metrics, 
on_step=False, on_epoch=True, logger=True, prog_bar=True)\n        self.log('val_loss', val_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)\n        return val_loss\n\n    def test_step(self, batch, batch_idx):\n        # Unpack batch\n        batch_data, batch_preprocessing = self._unpack_batch(batch)\n\n        # Extract mask and target\n        eval_mask = batch_data.pop('eval_mask', None)\n        y = batch_data.pop('y')\n\n        # Compute outputs and rescale\n        imputation = self.predict_batch(batch, preprocess=False, postprocess=True)\n        test_loss = self.loss_fn(imputation, y, eval_mask)\n\n        # Logging\n        self.test_metrics.update(imputation.detach(), y, eval_mask)\n        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)\n        return test_loss\n\n    def on_train_epoch_start(self) -> None:\n        optimizers = ensure_list(self.optimizers())\n        for i, optimizer in enumerate(optimizers):\n            lr = optimizer.optimizer.param_groups[0]['lr']\n            self.log(f'lr_{i}', lr, on_step=False, on_epoch=True, logger=True, prog_bar=False)\n\n    def configure_optimizers(self):\n        cfg = dict()\n        optimizer = self.optim_class(self.parameters(), **self.optim_kwargs)\n        cfg['optimizer'] = optimizer\n        if self.scheduler_class is not None:\n            metric = self.scheduler_kwargs.pop('monitor', None)\n            scheduler = self.scheduler_class(optimizer, **self.scheduler_kwargs)\n            cfg['lr_scheduler'] = scheduler\n            if metric is not None:\n                cfg['monitor'] = metric\n        return cfg\n"
  },
  {
    "path": "lib/fillers/graphfiller.py",
    "content": "import torch\n\nfrom . import Filler\nfrom ..nn.models import MPGRUNet, GRINet, BiMPGRUNet\n\n\nclass GraphFiller(Filler):\n\n    def __init__(self,\n                 model_class,\n                 model_kwargs,\n                 optim_class,\n                 optim_kwargs,\n                 loss_fn,\n                 scaled_target=False,\n                 whiten_prob=0.05,\n                 pred_loss_weight=1.,\n                 warm_up=0,\n                 metrics=None,\n                 scheduler_class=None,\n                 scheduler_kwargs=None):\n        super(GraphFiller, self).__init__(model_class=model_class,\n                                          model_kwargs=model_kwargs,\n                                          optim_class=optim_class,\n                                          optim_kwargs=optim_kwargs,\n                                          loss_fn=loss_fn,\n                                          scaled_target=scaled_target,\n                                          whiten_prob=whiten_prob,\n                                          metrics=metrics,\n                                          scheduler_class=scheduler_class,\n                                          scheduler_kwargs=scheduler_kwargs)\n\n        self.tradeoff = pred_loss_weight\n        if model_class is MPGRUNet:\n            self.trimming = (warm_up, 0)\n        elif model_class in [GRINet, BiMPGRUNet]:\n            self.trimming = (warm_up, warm_up)\n\n    def trim_seq(self, *seq):\n        seq = [s[:, self.trimming[0]:s.size(1) - self.trimming[1]] for s in seq]\n        if len(seq) == 1:\n            return seq[0]\n        return seq\n\n    def training_step(self, batch, batch_idx):\n        # Unpack batch\n        batch_data, batch_preprocessing = self._unpack_batch(batch)\n\n        # Compute masks\n        mask = batch_data['mask'].clone().detach()\n        batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * 
self.keep_prob).byte()\n        eval_mask = batch_data.pop('eval_mask', None)\n        eval_mask = (mask | eval_mask) - batch_data['mask']  # all unseen data\n\n        y = batch_data.pop('y')\n\n        # Compute predictions and compute loss\n        res = self.predict_batch(batch, preprocess=False, postprocess=False)\n        imputation, predictions = (res[0], res[1:]) if isinstance(res, (list, tuple)) else (res, [])\n\n        # trim to imputation horizon len\n        imputation, mask, eval_mask, y = self.trim_seq(imputation, mask, eval_mask, y)\n        predictions = self.trim_seq(*predictions)\n\n        if self.scaled_target:\n            target = self._preprocess(y, batch_preprocessing)\n        else:\n            target = y\n            imputation = self._postprocess(imputation, batch_preprocessing)\n            for i, _ in enumerate(predictions):\n                predictions[i] = self._postprocess(predictions[i], batch_preprocessing)\n\n        loss = self.loss_fn(imputation, target, mask)\n        for pred in predictions:\n            loss += self.tradeoff * self.loss_fn(pred, target, mask)\n\n        # Logging\n        if self.scaled_target:\n            imputation = self._postprocess(imputation, batch_preprocessing)\n        self.train_metrics.update(imputation.detach(), y, eval_mask)  # all unseen data\n        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)\n        self.log('train_loss', loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)\n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        # Unpack batch\n        batch_data, batch_preprocessing = self._unpack_batch(batch)\n\n        # Extract mask and target\n        mask = batch_data.get('mask')\n        eval_mask = batch_data.pop('eval_mask', None)\n        y = batch_data.pop('y')\n\n        # Compute predictions and compute loss\n        imputation = self.predict_batch(batch, preprocess=False, 
postprocess=False)\n\n        # trim to imputation horizon len\n        imputation, mask, eval_mask, y = self.trim_seq(imputation, mask, eval_mask, y)\n\n        if self.scaled_target:\n            target = self._preprocess(y, batch_preprocessing)\n        else:\n            target = y\n            imputation = self._postprocess(imputation, batch_preprocessing)\n\n        val_loss = self.loss_fn(imputation, target, eval_mask)\n\n        # Logging\n        if self.scaled_target:\n            imputation = self._postprocess(imputation, batch_preprocessing)\n        self.val_metrics.update(imputation.detach(), y, eval_mask)\n        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)\n        self.log('val_loss', val_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)\n        return val_loss\n\n    def test_step(self, batch, batch_idx):\n        # Unpack batch\n        batch_data, batch_preprocessing = self._unpack_batch(batch)\n\n        # Extract mask and target\n        eval_mask = batch_data.pop('eval_mask', None)\n        y = batch_data.pop('y')\n\n        # Compute outputs and rescale\n        imputation = self.predict_batch(batch, preprocess=False, postprocess=True)\n        test_loss = self.loss_fn(imputation, y, eval_mask)\n\n        # Logging\n        self.test_metrics.update(imputation.detach(), y, eval_mask)\n        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)\n        self.log('test_loss', test_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)\n        return test_loss\n"
  },
  {
    "path": "lib/fillers/multi_imputation_filler.py",
    "content": "import torch\nfrom pytorch_lightning.core.decorators import auto_move_data\n\nfrom . import Filler\n\n\nclass MultiImputationFiller(Filler):\n    \"\"\"\n    Filler with multiple imputation outputs\n    \"\"\"\n\n    def __init__(self,\n                 model_class,\n                 model_kwargs,\n                 optim_class,\n                 optim_kwargs,\n                 loss_fn,\n                 consistency_loss=False,\n                 scaled_target=False,\n                 whiten_prob=0.05,\n                 metrics=None,\n                 scheduler_class=None,\n                 scheduler_kwargs=None):\n\n        super().__init__(model_class,\n                         model_kwargs,\n                         optim_class,\n                         optim_kwargs,\n                         loss_fn,\n                         scaled_target,\n                         whiten_prob,\n                         metrics,\n                         scheduler_class,\n                         scheduler_kwargs)\n        self.consistency_loss = consistency_loss\n\n    @auto_move_data\n    def forward(self, *args, **kwargs):\n        out = self.model(*args, **kwargs)\n        assert isinstance(out, (list, tuple))\n        if self.training:\n            return out\n        return out[0]  # we assume that the final imputation is the first one\n\n    def _consistency_loss(self, imputations, mask):\n        from itertools import combinations\n        return sum([self.loss_fn(imp1, imp2, mask) for imp1, imp2 in combinations(imputations, 2)])\n\n    def training_step(self, batch, batch_idx):\n        # Unpack batch\n        batch_data, batch_preprocessing = self._unpack_batch(batch)\n\n        # Extract mask and target\n        mask = batch_data['mask'].clone().detach()\n        batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()\n        eval_mask = batch_data.pop('eval_mask', None)\n        y = batch_data.pop('y')\n\n        
# Compute predictions and compute loss\n        imputations = self.predict_batch(batch, preprocess=False, postprocess=False)\n\n        if self.scaled_target:\n            target = self._preprocess(y, batch_preprocessing)\n        else:\n            target = y\n            imputations = [self._postprocess(imp, batch_preprocessing) for imp in imputations]\n\n        loss = sum([self.loss_fn(imp, target, mask) for imp in imputations])\n        if self.consistency_loss:\n            loss += self._consistency_loss(imputations, mask)\n\n        # Logging\n        metrics_mask = (mask | eval_mask) - batch_data['mask']  # all unseen data\n\n        x_hat = imputations[0]\n        x_hat = self._postprocess(x_hat, batch_preprocessing)\n        self.train_metrics.update(x_hat.detach(), y, metrics_mask)\n        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)\n        self.log('train_loss', loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)\n        return loss\n"
  },
  {
    "path": "lib/fillers/rgainfiller.py",
    "content": "import torch\nfrom torch.nn import functional as F\n\nfrom .multi_imputation_filler import MultiImputationFiller\nfrom ..nn.utils.metric_base import MaskedMetric\n\n\nclass MaskedBCEWithLogits(MaskedMetric):\n    def __init__(self,\n                 mask_nans=False,\n                 mask_inf=False,\n                 compute_on_step=True,\n                 dist_sync_on_step=False,\n                 process_group=None,\n                 dist_sync_fn=None,\n                 at=None):\n        super(MaskedBCEWithLogits, self).__init__(metric_fn=F.binary_cross_entropy_with_logits,\n                                                  mask_nans=mask_nans,\n                                                  mask_inf=mask_inf,\n                                                  compute_on_step=compute_on_step,\n                                                  dist_sync_on_step=dist_sync_on_step,\n                                                  process_group=process_group,\n                                                  dist_sync_fn=dist_sync_fn,\n                                                  metric_kwargs={'reduction': 'none'},\n                                                  at=at)\n\n\nclass RGAINFiller(MultiImputationFiller):\n    def __init__(self,\n                 model_class,\n                 model_kwargs,\n                 optim_class,\n                 optim_kwargs,\n                 loss_fn,\n                 g_train_freq=1,\n                 d_train_freq=5,\n                 consistency_loss=False,\n                 scaled_target=True,\n                 whiten_prob=0.05,\n                 hint_rate=0.7,\n                 alpha=10.,\n                 metrics=None,\n                 scheduler_class=None,\n                 scheduler_kwargs=None):\n        super(RGAINFiller, self).__init__(model_class=model_class,\n                                          model_kwargs=model_kwargs,\n                                          
optim_class=optim_class,\n                                          optim_kwargs=optim_kwargs,\n                                          loss_fn=loss_fn,\n                                          scaled_target=scaled_target,\n                                          whiten_prob=whiten_prob,\n                                          metrics=metrics,\n                                          consistency_loss=consistency_loss,\n                                          scheduler_class=scheduler_class,\n                                          scheduler_kwargs=scheduler_kwargs)\n        # discriminator training params\n        self.alpha = alpha\n        self.g_train_freq = g_train_freq\n        self.d_train_freq = d_train_freq\n        self.masked_bce_loss = MaskedBCEWithLogits(compute_on_step=True)\n        # activate manual optimization\n        self.automatic_optimization = False\n        self.hint_rate = hint_rate\n\n    def training_step(self, batch, batch_idx):\n        # Unpack batch\n        batch_data, batch_preprocessing = self._unpack_batch(batch)\n        g_opt, d_opt = self.optimizers()\n        schedulers = self.lr_schedulers()\n\n        # Extract mask and target\n        x = batch_data.pop('x')\n        mask = batch_data['mask'].clone().detach()\n        training_mask = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()\n        eval_mask = batch_data.pop('eval_mask', None)\n        y = batch_data.pop('y')\n\n        ##########################\n        #  generate imputations\n        ##########################\n\n        imputations = self.model.generator(x, training_mask)\n        imputed_seq = imputations[0]\n        target = self._preprocess(y, batch_preprocessing)\n        y_hat = self._postprocess(imputed_seq, batch_preprocessing)\n\n        x_in = training_mask * x + (1 - training_mask) * imputed_seq\n        hint = torch.rand_like(training_mask, dtype=torch.float) < self.hint_rate\n        hint = hint.byte()\n        
hint = hint * training_mask + (1 - hint) * 0.5\n\n        #########################\n        #  train generator\n        #########################\n        if (batch_idx % self.g_train_freq) == 0:\n\n            g_opt.zero_grad()\n\n            rec_loss = sum([torch.sqrt(self.loss_fn(imp, target, mask)) for imp in imputations])\n            if self.consistency_loss:\n                rec_loss += self._consistency_loss(imputations, mask)\n\n            logits = self.model.discriminator(x_in, hint)\n            # maximize logit\n            adv_loss = self.masked_bce_loss(logits, torch.ones_like(logits), 1 - training_mask)\n\n            g_loss = self.alpha * rec_loss + adv_loss\n\n            self.manual_backward(g_loss)\n            g_opt.step()\n\n            # Logging\n            metrics_mask = (mask | eval_mask) - training_mask\n            self.train_metrics.update(y_hat.detach(), y, metrics_mask)  # all unseen data\n            self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)\n            self.log('gen_loss', adv_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)\n            self.log('imp_loss', rec_loss.detach(), on_step=True, on_epoch=True, logger=True, prog_bar=True)\n\n        ###########################\n        # train discriminator\n        ###########################\n\n        if (batch_idx % self.d_train_freq) == 0:\n            d_opt.zero_grad()\n\n            logits = self.model.discriminator(x_in.detach(), hint)\n            d_loss = self.masked_bce_loss(logits, training_mask.to(logits.dtype))\n\n            self.manual_backward(d_loss)\n            d_opt.step()\n            self.log('d_loss', d_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)\n\n        if (schedulers is not None) and self.trainer.is_last_batch:\n            for sch in schedulers:\n                sch.step()\n\n    def configure_optimizers(self):\n        opt_g = 
self.optim_class(self.model.generator.parameters(), **self.optim_kwargs)\n        opt_d = self.optim_class(self.model.discriminator.parameters(), **self.optim_kwargs)\n        optimizers = [opt_g, opt_d]\n        if self.scheduler_class is not None:\n            metric = self.scheduler_kwargs.pop('monitor', None)\n            schedulers = [{\"scheduler\": self.scheduler_class(opt, **self.scheduler_kwargs), \"monitor\": metric}\n                          for opt in optimizers]\n            return optimizers, schedulers\n        return optimizers\n"
  },
  {
    "path": "lib/nn/__init__.py",
    "content": "from .layers import *\n"
  },
  {
    "path": "lib/nn/layers/__init__.py",
    "content": "from .rits import RITS, BRITS\nfrom .gril import GRIL, BiGRIL\nfrom .spatial_conv import SpatialConvOrderK\nfrom .mpgru import MPGRUImputer\n"
  },
  {
    "path": "lib/nn/layers/gcrnn.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom .spatial_conv import SpatialConvOrderK\n\n\nclass GCGRUCell(nn.Module):\n    \"\"\"\n    Graph Convolution Gated Recurrent Unit Cell.\n    \"\"\"\n\n    def __init__(self, d_in, num_units, support_len, order, activation='tanh'):\n        \"\"\"\n        :param num_units: the hidden dim of rnn\n        :param support_len: the (weighted) adjacency matrix of the graph, in numpy ndarray form\n        :param order: the max diffusion step\n        :param activation: if None, don't do activation for cell state\n        \"\"\"\n        super(GCGRUCell, self).__init__()\n        self.activation_fn = getattr(torch, activation)\n\n        self.forget_gate = SpatialConvOrderK(c_in=d_in + num_units, c_out=num_units, support_len=support_len,\n                                             order=order)\n        self.update_gate = SpatialConvOrderK(c_in=d_in + num_units, c_out=num_units, support_len=support_len,\n                                             order=order)\n        self.c_gate = SpatialConvOrderK(c_in=d_in + num_units, c_out=num_units, support_len=support_len, order=order)\n\n    def forward(self, x, h, adj):\n        \"\"\"\n        :param x: (B, input_dim, num_nodes)\n        :param h: (B, num_units, num_nodes)\n        :param adj: (num_nodes, num_nodes)\n        :return:\n        \"\"\"\n        # we start with bias 1.0 to not reset and not update\n        x_gates = torch.cat([x, h], dim=1)\n        r = torch.sigmoid(self.forget_gate(x_gates, adj))\n        u = torch.sigmoid(self.update_gate(x_gates, adj))\n        x_c = torch.cat([x, r * h], dim=1)\n        c = self.c_gate(x_c, adj)  # batch_size, self._num_nodes * output_size\n        c = self.activation_fn(c)\n        return u * h + (1. 
- u) * c\n\n\nclass GCRNN(nn.Module):\n    def __init__(self,\n                 d_in,\n                 d_model,\n                 d_out,\n                 n_layers,\n                 support_len,\n                 kernel_size=2):\n        super(GCRNN, self).__init__()\n        self.d_in = d_in\n        self.d_model = d_model\n        self.d_out = d_out\n        self.n_layers = n_layers\n        self.ks = kernel_size\n        self.support_len = support_len\n        self.rnn_cells = nn.ModuleList()\n        for i in range(self.n_layers):\n            self.rnn_cells.append(GCGRUCell(d_in=self.d_in if i == 0 else self.d_model,\n                                            num_units=self.d_model, support_len=self.support_len, order=self.ks))\n        self.output_layer = nn.Conv2d(self.d_model, self.d_out, kernel_size=1)\n\n    def init_hidden_states(self, x):\n        return [torch.zeros(size=(x.shape[0], self.d_model, x.shape[2])).to(x.device) for _ in range(self.n_layers)]\n\n    def single_pass(self, x, h, adj):\n        out = x\n        for l, layer in enumerate(self.rnn_cells):\n            out = h[l] = layer(out, h[l], adj)\n        return out, h\n\n    def forward(self, x, adj, h=None):\n        # x:[batch, features, nodes, steps]\n        *_, steps = x.size()\n        if h is None:\n            h = self.init_hidden_states(x)\n        # temporal conv\n        for step in range(steps):\n            out, h = self.single_pass(x[..., step], h, adj)\n\n        return self.output_layer(out[..., None])\n"
  },
  {
    "path": "lib/nn/layers/gril.py",
    "content": "import torch\nimport torch.nn as nn\nfrom einops import rearrange\n\nfrom .spatial_conv import SpatialConvOrderK\nfrom .gcrnn import GCGRUCell\nfrom .spatial_attention import SpatialAttention\nfrom ..utils.ops import reverse_tensor\n\n\nclass SpatialDecoder(nn.Module):\n    def __init__(self, d_in, d_model, d_out, support_len, order=1, attention_block=False, nheads=2, dropout=0.):\n        super(SpatialDecoder, self).__init__()\n        self.order = order\n        self.lin_in = nn.Conv1d(d_in, d_model, kernel_size=1)\n        self.graph_conv = SpatialConvOrderK(c_in=d_model, c_out=d_model,\n                                            support_len=support_len * order, order=1, include_self=False)\n        if attention_block:\n            self.spatial_att = SpatialAttention(d_in=d_model,\n                                                d_model=d_model,\n                                                nheads=nheads,\n                                                dropout=dropout)\n            self.lin_out = nn.Conv1d(3 * d_model, d_model, kernel_size=1)\n        else:\n            self.register_parameter('spatial_att', None)\n            self.lin_out = nn.Conv1d(2 * d_model, d_model, kernel_size=1)\n        self.read_out = nn.Conv1d(2 * d_model, d_out, kernel_size=1)\n        self.activation = nn.PReLU()\n        self.adj = None\n\n    def forward(self, x, m, h, u, adj, cached_support=False):\n        # [batch, channels, nodes]\n        x_in = [x, m, h] if u is None else [x, m, u, h]\n        x_in = torch.cat(x_in, 1)\n        if self.order > 1:\n            if cached_support and (self.adj is not None):\n                adj = self.adj\n            else:\n                adj = SpatialConvOrderK.compute_support_orderK(adj, self.order, include_self=False, device=x_in.device)\n                self.adj = adj if cached_support else None\n\n        x_in = self.lin_in(x_in)\n        out = self.graph_conv(x_in, adj)\n        if self.spatial_att is not None:\n   
         # [batch, channels, nodes] -> [batch, steps, nodes, features]\n            x_in = rearrange(x_in, 'b f n -> b 1 n f')\n            out_att = self.spatial_att(x_in, torch.eye(x_in.size(2), dtype=torch.bool, device=x_in.device))\n            out_att = rearrange(out_att, 'b s n f -> b f (n s)')\n            out = torch.cat([out, out_att], 1)\n        out = torch.cat([out, h], 1)\n        out = self.activation(self.lin_out(out))\n        # out = self.lin_out(out)\n        out = torch.cat([out, h], 1)\n        return self.read_out(out), out\n\n\nclass GRIL(nn.Module):\n    def __init__(self,\n                 input_size,\n                 hidden_size,\n                 u_size=None,\n                 n_layers=1,\n                 dropout=0.,\n                 kernel_size=2,\n                 decoder_order=1,\n                 global_att=False,\n                 support_len=2,\n                 n_nodes=None,\n                 layer_norm=False):\n        super(GRIL, self).__init__()\n        self.input_size = int(input_size)\n        self.hidden_size = int(hidden_size)\n        self.u_size = int(u_size) if u_size is not None else 0\n        self.n_layers = int(n_layers)\n        rnn_input_size = 2 * self.input_size + self.u_size  # input + mask + (eventually) exogenous\n\n        # Spatio-temporal encoder (rnn_input_size -> hidden_size)\n        self.cells = nn.ModuleList()\n        self.norms = nn.ModuleList()\n        for i in range(self.n_layers):\n            self.cells.append(GCGRUCell(d_in=rnn_input_size if i == 0 else self.hidden_size,\n                                        num_units=self.hidden_size, support_len=support_len, order=kernel_size))\n            if layer_norm:\n                self.norms.append(nn.GroupNorm(num_groups=1, num_channels=self.hidden_size))\n            else:\n                self.norms.append(nn.Identity())\n        self.dropout = nn.Dropout(dropout) if dropout > 0. 
else None\n\n        # First stage readout\n        self.first_stage = nn.Conv1d(in_channels=self.hidden_size, out_channels=self.input_size, kernel_size=1)\n\n        # Spatial decoder (rnn_input_size + hidden_size -> hidden_size)\n        self.spatial_decoder = SpatialDecoder(d_in=rnn_input_size + self.hidden_size,\n                                              d_model=self.hidden_size,\n                                              d_out=self.input_size,\n                                              support_len=2,\n                                              order=decoder_order,\n                                              attention_block=global_att)\n\n        # Hidden state initialization embedding\n        if n_nodes is not None:\n            self.h0 = self.init_hidden_states(n_nodes)\n        else:\n            self.register_parameter('h0', None)\n\n    def init_hidden_states(self, n_nodes):\n        h0 = []\n        for l in range(self.n_layers):\n            std = 1. / torch.sqrt(torch.tensor(self.hidden_size, dtype=torch.float))\n            vals = torch.distributions.Normal(0, std).sample((self.hidden_size, n_nodes))\n            h0.append(nn.Parameter(vals))\n        return nn.ParameterList(h0)\n\n    def get_h0(self, x):\n        if self.h0 is not None:\n            return [h.expand(x.shape[0], -1, -1) for h in self.h0]\n        return [torch.zeros(size=(x.shape[0], self.hidden_size, x.shape[2])).to(x.device)] * self.n_layers\n\n    def update_state(self, x, h, adj):\n        rnn_in = x\n        for layer, (cell, norm) in enumerate(zip(self.cells, self.norms)):\n            rnn_in = h[layer] = norm(cell(rnn_in, h[layer], adj))\n            if self.dropout is not None and layer < (self.n_layers - 1):\n                rnn_in = self.dropout(rnn_in)\n        return h\n\n    def forward(self, x, adj, mask=None, u=None, h=None, cached_support=False):\n        # x:[batch, features, nodes, steps]\n        *_, steps = x.size()\n\n        # infer all valid if 
mask is None\n        if mask is None:\n            mask = torch.ones_like(x, dtype=torch.uint8)\n\n        # init hidden state using node embedding or the empty state\n        if h is None:\n            h = self.get_h0(x)\n        elif not isinstance(h, list):\n            h = [*h]\n\n        # Temporal conv\n        predictions, imputations, states = [], [], []\n        representations = []\n        for step in range(steps):\n            x_s = x[..., step]\n            m_s = mask[..., step]\n            h_s = h[-1]\n            u_s = u[..., step] if u is not None else None\n            # firstly impute missing values with predictions from state\n            xs_hat_1 = self.first_stage(h_s)\n            # fill missing values in input with prediction\n            x_s = torch.where(m_s, x_s, xs_hat_1)\n            # prepare inputs\n            # retrieve maximum information from neighbors\n            xs_hat_2, repr_s = self.spatial_decoder(x=x_s, m=m_s, h=h_s, u=u_s, adj=adj,\n                                                    cached_support=cached_support)  # receive messages from neighbors (no self-loop!)\n            # readout of imputation state + mask to retrieve imputations\n            # prepare inputs\n            x_s = torch.where(m_s, x_s, xs_hat_2)\n            inputs = [x_s, m_s]\n            if u_s is not None:\n                inputs.append(u_s)\n            inputs = torch.cat(inputs, dim=1)  # x_hat_2 + mask + exogenous\n            # update state with original sequence filled using imputations\n            h = self.update_state(inputs, h, adj)\n            # store imputations and states\n            imputations.append(xs_hat_2)\n            predictions.append(xs_hat_1)\n            states.append(torch.stack(h, dim=0))\n            representations.append(repr_s)\n\n        # Aggregate outputs -> [batch, features, nodes, steps]\n        imputations = torch.stack(imputations, dim=-1)\n        predictions = torch.stack(predictions, dim=-1)\n        
states = torch.stack(states, dim=-1)\n        representations = torch.stack(representations, dim=-1)\n\n        return imputations, predictions, representations, states\n\n\nclass BiGRIL(nn.Module):\n    def __init__(self,\n                 input_size,\n                 hidden_size,\n                 ff_size,\n                 ff_dropout,\n                 n_layers=1,\n                 dropout=0.,\n                 n_nodes=None,\n                 support_len=2,\n                 kernel_size=2,\n                 decoder_order=1,\n                 global_att=False,\n                 u_size=0,\n                 embedding_size=0,\n                 layer_norm=False,\n                 merge='mlp'):\n        super(BiGRIL, self).__init__()\n        self.fwd_rnn = GRIL(input_size=input_size,\n                            hidden_size=hidden_size,\n                            n_layers=n_layers,\n                            dropout=dropout,\n                            n_nodes=n_nodes,\n                            support_len=support_len,\n                            kernel_size=kernel_size,\n                            decoder_order=decoder_order,\n                            global_att=global_att,\n                            u_size=u_size,\n                            layer_norm=layer_norm)\n        self.bwd_rnn = GRIL(input_size=input_size,\n                            hidden_size=hidden_size,\n                            n_layers=n_layers,\n                            dropout=dropout,\n                            n_nodes=n_nodes,\n                            support_len=support_len,\n                            kernel_size=kernel_size,\n                            decoder_order=decoder_order,\n                            global_att=global_att,\n                            u_size=u_size,\n                            layer_norm=layer_norm)\n\n        if n_nodes is None:\n            embedding_size = 0\n        if embedding_size > 0:\n            self.emb = 
nn.Parameter(torch.empty(embedding_size, n_nodes))\n            nn.init.kaiming_normal_(self.emb, nonlinearity='relu')\n        else:\n            self.register_parameter('emb', None)\n\n        if merge == 'mlp':\n            self._impute_from_states = True\n            self.out = nn.Sequential(\n                nn.Conv2d(in_channels=4 * hidden_size + input_size + embedding_size,\n                          out_channels=ff_size, kernel_size=1),\n                nn.ReLU(),\n                nn.Dropout(ff_dropout),\n                nn.Conv2d(in_channels=ff_size, out_channels=input_size, kernel_size=1)\n            )\n        elif merge in ['mean', 'sum', 'min', 'max']:\n            self._impute_from_states = False\n            self.out = getattr(torch, merge)\n        else:\n            raise ValueError(\"Merge option %s not allowed.\" % merge)\n        self.supp = None\n\n    def forward(self, x, adj, mask=None, u=None, cached_support=False):\n        if cached_support and (self.supp is not None):\n            supp = self.supp\n        else:\n            supp = SpatialConvOrderK.compute_support(adj, x.device)\n            self.supp = supp if cached_support else None\n        # Forward\n        fwd_out, fwd_pred, fwd_repr, _ = self.fwd_rnn(x, supp, mask=mask, u=u, cached_support=cached_support)\n        # Backward\n        rev_x, rev_mask, rev_u = [reverse_tensor(tens) for tens in (x, mask, u)]\n        *bwd_res, _ = self.bwd_rnn(rev_x, supp, mask=rev_mask, u=rev_u, cached_support=cached_support)\n        bwd_out, bwd_pred, bwd_repr = [reverse_tensor(res) for res in bwd_res]\n\n        if self._impute_from_states:\n            inputs = [fwd_repr, bwd_repr, mask]\n            if self.emb is not None:\n                b, *_, s = fwd_repr.shape  # fwd_h: [batches, channels, nodes, steps]\n                inputs += [self.emb.view(1, *self.emb.shape, 1).expand(b, -1, -1, s)]  # stack emb for batches and steps\n            imputation = torch.cat(inputs, dim=1)\n            
imputation = self.out(imputation)\n        else:\n            imputation = torch.stack([fwd_out, bwd_out], dim=1)\n            imputation = self.out(imputation, dim=1)\n\n        predictions = torch.stack([fwd_out, bwd_out, fwd_pred, bwd_pred], dim=0)\n\n        return imputation, predictions\n"
  },
  {
    "path": "lib/nn/layers/imputation.py",
    "content": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass ImputationLayer(nn.Module):\n    def __init__(self, d_in, bias=True):\n        super(ImputationLayer, self).__init__()\n        self.W = nn.Parameter(torch.Tensor(d_in, d_in))\n        if bias:\n            self.b = nn.Parameter(torch.Tensor(d_in))\n        else:\n            self.register_buffer('b', None)\n        mask = 1. - torch.eye(d_in)\n        self.register_buffer('mask', mask)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))\n        if self.b is not None:\n            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W)\n            bound = 1 / math.sqrt(fan_in)\n            nn.init.uniform_(self.b, -bound, bound)\n\n    def forward(self, x):\n        # batch, features\n        return F.linear(x, self.mask * self.W, self.b)"
  },
  {
    "path": "lib/nn/layers/mpgru.py",
    "content": "import torch\nfrom torch import nn\n\nfrom .gcrnn import GCGRUCell\n\n\nclass MPGRUImputer(nn.Module):\n    def __init__(self,\n                 input_size,\n                 hidden_size,\n                 ff_size=None,\n                 u_size=None,\n                 n_layers=1,\n                 dropout=0.,\n                 kernel_size=2,\n                 support_len=2,\n                 n_nodes=None,\n                 layer_norm=False,\n                 autoencoder_mode=False):\n        super(MPGRUImputer, self).__init__()\n        self.input_size = int(input_size)\n        self.hidden_size = int(hidden_size)\n        self.ff_size = int(ff_size) if ff_size is not None else 0\n        self.u_size = int(u_size) if u_size is not None else 0\n        self.n_layers = int(n_layers)\n        rnn_input_size = 2 * self.input_size + self.u_size  # input + mask + (eventually) exogenous\n\n        # Spatio-temporal encoder (rnn_input_size -> hidden_size)\n        self.cells = nn.ModuleList()\n        self.norms = nn.ModuleList()\n        for i in range(self.n_layers):\n            self.cells.append(GCGRUCell(d_in=rnn_input_size if i == 0 else self.hidden_size,\n                                        num_units=self.hidden_size, support_len=support_len, order=kernel_size))\n            if layer_norm:\n                self.norms.append(nn.GroupNorm(num_groups=1, num_channels=self.hidden_size))\n            else:\n                self.norms.append(nn.Identity())\n        self.dropout = nn.Dropout(dropout) if dropout > 0. 
else None\n\n        # Readout\n        if self.ff_size:\n            self.pred_readout = nn.Sequential(\n                nn.Conv1d(in_channels=self.hidden_size, out_channels=self.ff_size, kernel_size=1),\n                nn.PReLU(),\n                nn.Conv1d(in_channels=self.ff_size, out_channels=self.input_size, kernel_size=1)\n            )\n        else:\n            self.pred_readout = nn.Conv1d(in_channels=self.hidden_size, out_channels=self.input_size, kernel_size=1)\n\n        # Hidden state initialization embedding\n        if n_nodes is not None:\n            self.h0 = self.init_hidden_states(n_nodes)\n        else:\n            self.register_parameter('h0', None)\n\n        self.autoencoder_mode = autoencoder_mode\n\n    def init_hidden_states(self, n_nodes):\n        h0 = []\n        for l in range(self.n_layers):\n            std = 1. / torch.sqrt(torch.tensor(self.hidden_size, dtype=torch.float))\n            vals = torch.distributions.Normal(0, std).sample((self.hidden_size, n_nodes))\n            h0.append(nn.Parameter(vals))\n        return nn.ParameterList(h0)\n\n    def get_h0(self, x):\n        if self.h0 is not None:\n            return [h.expand(x.shape[0], -1, -1) for h in self.h0]\n        return [torch.zeros(size=(x.shape[0], self.hidden_size, x.shape[2])).to(x.device)] * self.n_layers\n\n    def update_state(self, x, h, adj):\n        rnn_in = x\n        for layer, (cell, norm) in enumerate(zip(self.cells, self.norms)):\n            rnn_in = h[layer] = norm(cell(rnn_in, h[layer], adj))\n            if self.dropout is not None and layer < (self.n_layers - 1):\n                rnn_in = self.dropout(rnn_in)\n        return h\n\n    def forward(self, x, adj, mask=None, u=None, h=None):\n        # x:[batch, features, nodes, steps]\n        *_, steps = x.size()\n\n        # infer all valid if mask is None\n        if mask is None:\n            mask = torch.ones_like(x, dtype=torch.uint8)\n\n        # init hidden state using node embedding or 
the empty state\n        if h is None:\n            h = self.get_h0(x)\n        elif not isinstance(h, list):\n            h = [*h]\n\n        # Temporal conv\n        predictions, states = [], []\n        for step in range(steps):\n            x_s = x[..., step]\n            m_s = mask[..., step]\n            h_s = h[-1]\n            u_s = u[..., step] if u is not None else None\n            # impute missing values with predictions from state\n            x_s_hat = self.pred_readout(h_s)\n            # store imputations and state\n            predictions.append(x_s_hat)\n            states.append(torch.stack(h, dim=0))\n            # fill missing values in input with prediction\n            x_s = torch.where(m_s, x_s, x_s_hat)\n            inputs = [x_s, m_s]\n            if u_s is not None:\n                inputs.append(u_s)\n            inputs = torch.cat(inputs, dim=1)  # x_hat complemented + mask + exogenous\n            # update state with original sequence filled using imputations\n            h = self.update_state(inputs, h, adj)\n\n        # In autoencoder mode use states after input processing\n        if self.autoencoder_mode:\n            states = states[1:] + [torch.stack(h, dim=0)]\n\n        # Aggregate outputs -> [batch, features, nodes, steps]\n        predictions = torch.stack(predictions, dim=-1)\n        states = torch.stack(states, dim=-1)\n\n        return predictions, states\n"
  },
  {
    "path": "lib/nn/layers/rits.py",
    "content": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nfrom torch.nn.parameter import Parameter\n\nfrom ..utils.ops import reverse_tensor\n\n\nclass FeatureRegression(nn.Module):\n    def __init__(self, input_size):\n        super(FeatureRegression, self).__init__()\n        self.W = Parameter(torch.Tensor(input_size, input_size))\n        self.b = Parameter(torch.Tensor(input_size))\n\n        m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size)\n        self.register_buffer('m', m)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        stdv = 1. / math.sqrt(self.W.shape[0])\n        self.W.data.uniform_(-stdv, stdv)\n        if self.b is not None:\n            self.b.data.uniform_(-stdv, stdv)\n\n    def forward(self, x):\n        z_h = F.linear(x, self.W * Variable(self.m), self.b)\n        return z_h\n\n\nclass TemporalDecay(nn.Module):\n    def __init__(self, d_in, d_out, diag=False):\n        super(TemporalDecay, self).__init__()\n        self.diag = diag\n        self.W = Parameter(torch.Tensor(d_out, d_in))\n        self.b = Parameter(torch.Tensor(d_out))\n\n        if self.diag:\n            assert (d_in == d_out)\n            m = torch.eye(d_in, d_in)\n            self.register_buffer('m', m)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        stdv = 1. 
/ math.sqrt(self.W.shape[0])\n        self.W.data.uniform_(-stdv, stdv)\n        if self.b is not None:\n            self.b.data.uniform_(-stdv, stdv)\n\n    @staticmethod\n    def compute_delta(mask, freq=1):\n        delta = torch.zeros_like(mask).float()\n        one_step = torch.tensor(freq, dtype=delta.dtype, device=delta.device)\n        for i in range(1, delta.shape[-2]):\n            m = mask[..., i - 1, :]\n            delta[..., i, :] = m * one_step + (1 - m) * torch.add(delta[..., i - 1, :], freq)\n        return delta\n\n    def forward(self, d):\n        if self.diag:\n            gamma = F.relu(F.linear(d, self.W * Variable(self.m), self.b))\n        else:\n            gamma = F.relu(F.linear(d, self.W, self.b))\n        gamma = torch.exp(-gamma)\n        return gamma\n\n\nclass RITS(nn.Module):\n    def __init__(self,\n                 input_size,\n                 hidden_size=64):\n        super(RITS, self).__init__()\n        self.input_size = int(input_size)\n        self.hidden_size = int(hidden_size)\n\n        self.rnn_cell = nn.LSTMCell(2 * self.input_size, self.hidden_size)\n\n        self.temp_decay_h = TemporalDecay(d_in=self.input_size, d_out=self.hidden_size, diag=False)\n        self.temp_decay_x = TemporalDecay(d_in=self.input_size, d_out=self.input_size, diag=True)\n\n        self.hist_reg = nn.Linear(self.hidden_size, self.input_size)\n        self.feat_reg = FeatureRegression(self.input_size)\n\n        self.weight_combine = nn.Linear(2 * self.input_size, self.input_size)\n\n    def init_hidden_states(self, x):\n        return Variable(torch.zeros((x.shape[0], self.hidden_size))).to(x.device)\n\n    def forward(self, x, mask=None, delta=None):\n        # x : [batch, steps, features]\n        steps = x.shape[-2]\n\n        if mask is None:\n            mask = torch.ones_like(x, dtype=torch.uint8)\n        if delta is None:\n            delta = TemporalDecay.compute_delta(mask)\n\n        # init rnn states\n        h = 
self.init_hidden_states(x)\n        c = self.init_hidden_states(x)\n\n        imputation = []\n        predictions = []\n        for step in range(steps):\n            d = delta[:, step, :]\n            m = mask[:, step, :]\n            x_s = x[:, step, :]\n\n            gamma_h = self.temp_decay_h(d)\n\n            # history prediction\n            x_h = self.hist_reg(h)\n            x_c = m * x_s + (1 - m) * x_h\n            h = h * gamma_h\n\n            # feature prediction\n            z_h = self.feat_reg(x_c)\n\n            # predictions combination\n            gamma_x = self.temp_decay_x(d)\n            alpha = self.weight_combine(torch.cat([gamma_x, m], dim=1))\n            alpha = torch.sigmoid(alpha)\n            c_h = alpha * z_h + (1 - alpha) * x_h\n\n            c_c = m * x_s + (1 - m) * c_h\n            inputs = torch.cat([c_c, m], dim=1)\n            h, c = self.rnn_cell(inputs, (h, c))\n\n            imputation.append(c_c)\n            predictions.append(torch.stack((c_h, z_h, x_h), dim=0))\n\n        # imputation -> [batch, steps, features]\n        imputation = torch.stack(imputation, dim=-2)\n        # predictions -> [predictions, batch, steps, features]\n        predictions = torch.stack(predictions, dim=-2)\n        c_h, z_h, x_h = predictions\n\n        return imputation, (c_h, z_h, x_h)\n\n\nclass BRITS(nn.Module):\n\n    def __init__(self, input_size, hidden_size):\n        super().__init__()\n        self.rits_fwd = RITS(input_size, hidden_size)\n        self.rits_bwd = RITS(input_size, hidden_size)\n\n    def forward(self, x, mask=None):\n        # x: [batches, steps, features]\n        # forward\n        imp_fwd, pred_fwd = self.rits_fwd(x, mask)\n        # backward\n        x_bwd = reverse_tensor(x, axis=1)\n        mask_bwd = reverse_tensor(mask, axis=1) if mask is not None else None\n        imp_bwd, pred_bwd = self.rits_bwd(x_bwd, mask_bwd)\n        imp_bwd, pred_bwd = reverse_tensor(imp_bwd, axis=1), [reverse_tensor(pb, axis=1) for 
pb in pred_bwd]\n        # stack into shape = [batch, directions, steps, features]\n        imputation = torch.stack([imp_fwd, imp_bwd], dim=1)\n        predictions = [torch.stack([pf, pb], dim=1) for pf, pb in zip(pred_fwd, pred_bwd)]\n        c_h, z_h, x_h = predictions\n\n        return imputation, (c_h, z_h, x_h)\n\n    @staticmethod\n    def consistency_loss(imp_fwd, imp_bwd):\n        loss = 0.1 * torch.abs(imp_fwd - imp_bwd).mean()\n        return loss\n"
  },
  {
    "path": "lib/nn/layers/spatial_attention.py",
    "content": "import torch.nn as nn\n\nfrom einops import rearrange\n\n\nclass SpatialAttention(nn.Module):\n    def __init__(self, d_in, d_model, nheads, dropout=0.):\n        super(SpatialAttention, self).__init__()\n        self.lin_in = nn.Linear(d_in, d_model)\n        self.self_attn = nn.MultiheadAttention(d_model, nheads, dropout=dropout)\n\n    def forward(self, x, att_mask=None, **kwargs):\n        r\"\"\"Pass the input through the encoder layer.\n\n        Args:\n            src: the sequence to the encoder layer (required).\n            src_mask: the mask for the src sequence (optional).\n            src_key_padding_mask: the mask for the src keys per batch (optional).\n\n        Shape:\n            see the docs in Transformer class.\n        \"\"\"\n        b, s, n, f = x.size()\n        x = rearrange(x, 'b s n f -> n (b s) f')\n        x = self.lin_in(x)\n        x = self.self_attn(x, x, x, attn_mask=att_mask)[0]\n        x = rearrange(x, 'n (b s) f -> b s n f', b=b, s=s)\n        return x\n"
  },
  {
    "path": "lib/nn/layers/spatial_conv.py",
    "content": "import torch\nfrom torch import nn\n\nfrom ... import epsilon\n\n\nclass SpatialConvOrderK(nn.Module):\n    \"\"\"\n    Spatial convolution of order K with possibly different diffusion matrices (useful for directed graphs)\n\n    Efficient implementation inspired from graph-wavenet codebase\n    \"\"\"\n\n    def __init__(self, c_in, c_out, support_len=3, order=2, include_self=True):\n        super(SpatialConvOrderK, self).__init__()\n        self.include_self = include_self\n        c_in = (order * support_len + (1 if include_self else 0)) * c_in\n        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1)\n        self.order = order\n\n    @staticmethod\n    def compute_support(adj, device=None):\n        if device is not None:\n            adj = adj.to(device)\n        adj_bwd = adj.T\n        adj_fwd = adj / (adj.sum(1, keepdims=True) + epsilon)\n        adj_bwd = adj_bwd / (adj_bwd.sum(1, keepdims=True) + epsilon)\n        support = [adj_fwd, adj_bwd]\n        return support\n\n    @staticmethod\n    def compute_support_orderK(adj, k, include_self=False, device=None):\n        if isinstance(adj, (list, tuple)):\n            support = adj\n        else:\n            support = SpatialConvOrderK.compute_support(adj, device)\n        supp_k = []\n        for a in support:\n            ak = a\n            for i in range(k - 1):\n                ak = torch.matmul(ak, a.T)\n                if not include_self:\n                    ak.fill_diagonal_(0.)\n                supp_k.append(ak)\n        return support + supp_k\n\n    def forward(self, x, support):\n        # [batch, features, nodes, steps]\n        if x.dim() < 4:\n            squeeze = True\n            x = torch.unsqueeze(x, -1)\n        else:\n            squeeze = False\n        out = [x] if self.include_self else []\n        if (type(support) is not list):\n            support = [support]\n        for a in support:\n            x1 = torch.einsum('ncvl,wv->ncwl', (x, a)).contiguous()\n       
     out.append(x1)\n            for k in range(2, self.order + 1):\n                x2 = torch.einsum('ncvl,wv->ncwl', (x1, a)).contiguous()\n                out.append(x2)\n                x1 = x2\n\n        out = torch.cat(out, dim=1)\n        out = self.mlp(out)\n        if squeeze:\n            out = out.squeeze(-1)\n        return out\n"
  },
  {
    "path": "lib/nn/models/__init__.py",
    "content": "from .grin import GRINet\nfrom .brits import BRITSNet\nfrom .mpgru import MPGRUNet, BiMPGRUNet\nfrom .var import VARImputer\nfrom .rgain import RGAINNet\nfrom .rnn_imputers import BiRNNImputer, RNNImputer\n"
  },
  {
    "path": "lib/nn/models/brits.py",
    "content": "import torch\nfrom torch import nn\n\nfrom ..layers import BRITS\n\n\nclass BRITSNet(nn.Module):\n    def __init__(self,\n                 d_in,\n                 d_hidden=64):\n        super(BRITSNet, self).__init__()\n        self.birits = BRITS(input_size=d_in,\n                            hidden_size=d_hidden)\n\n    def forward(self, x, mask=None, **kwargs):\n        # x: [batches, steps, features]\n        imputations, predictions = self.birits(x, mask=mask)\n        # predictions: [batch, directions, steps, features] x 3\n        out = torch.mean(imputations, dim=1)  # -> [batch, steps, features]\n        predictions = torch.cat(predictions, dim=1)  # -> [batch, directions * n_predictions, steps, features]\n        # reshape\n        imputations = torch.transpose(imputations, 0, 1)  # rearrange(imputations, 'b d s f -> d b s f')\n        predictions = torch.transpose(predictions, 0, 1)  # rearrange(predictions, 'b d s f -> d b s f')\n        return out, imputations, predictions\n\n    @staticmethod\n    def add_model_specific_args(parser):\n        parser.add_argument('--d-in', type=int)\n        parser.add_argument('--d-hidden', type=int, default=64)\n        return parser\n"
  },
  {
    "path": "lib/nn/models/grin.py",
    "content": "import torch\nfrom einops import rearrange\nfrom torch import nn\n\nfrom ..layers import BiGRIL\nfrom ...utils.parser_utils import str_to_bool\n\n\nclass GRINet(nn.Module):\n    def __init__(self,\n                 adj,\n                 d_in,\n                 d_hidden,\n                 d_ff,\n                 ff_dropout,\n                 n_layers=1,\n                 kernel_size=2,\n                 decoder_order=1,\n                 global_att=False,\n                 d_u=0,\n                 d_emb=0,\n                 layer_norm=False,\n                 merge='mlp',\n                 impute_only_holes=True):\n        super(GRINet, self).__init__()\n        self.d_in = d_in\n        self.d_hidden = d_hidden\n        self.d_u = int(d_u) if d_u is not None else 0\n        self.d_emb = int(d_emb) if d_emb is not None else 0\n        self.register_buffer('adj', torch.tensor(adj).float())\n        self.impute_only_holes = impute_only_holes\n\n        self.bigrill = BiGRIL(input_size=self.d_in,\n                              ff_size=d_ff,\n                              ff_dropout=ff_dropout,\n                              hidden_size=self.d_hidden,\n                              embedding_size=self.d_emb,\n                              n_nodes=self.adj.shape[0],\n                              n_layers=n_layers,\n                              kernel_size=kernel_size,\n                              decoder_order=decoder_order,\n                              global_att=global_att,\n                              u_size=self.d_u,\n                              layer_norm=layer_norm,\n                              merge=merge)\n\n    def forward(self, x, mask=None, u=None, **kwargs):\n        # x: [batches, steps, nodes, channels] -> [batches, channels, nodes, steps]\n        x = rearrange(x, 'b s n c -> b c n s')\n        if mask is not None:\n            mask = rearrange(mask, 'b s n c -> b c n s')\n\n        if u is not None:\n            u = 
rearrange(u, 'b s n c -> b c n s')\n\n        # imputation: [batches, channels, nodes, steps] prediction: [4, batches, channels, nodes, steps]\n        imputation, prediction = self.bigrill(x, self.adj, mask=mask, u=u, cached_support=self.training)\n        # In evaluation stage impute only missing values\n        if self.impute_only_holes and not self.training:\n            imputation = torch.where(mask, x, imputation)\n        # out: [batches, channels, nodes, steps] -> [batches, steps, nodes, channels]\n        imputation = torch.transpose(imputation, -3, -1)\n        prediction = torch.transpose(prediction, -3, -1)\n        if self.training:\n            return imputation, prediction\n        return imputation\n\n    @staticmethod\n    def add_model_specific_args(parser):\n        parser.add_argument('--d-hidden', type=int, default=64)\n        parser.add_argument('--d-ff', type=int, default=64)\n        parser.add_argument('--ff-dropout', type=float, default=0.)\n        parser.add_argument('--n-layers', type=int, default=1)\n        parser.add_argument('--kernel-size', type=int, default=2)\n        parser.add_argument('--decoder-order', type=int, default=1)\n        parser.add_argument('--d-u', type=int, default=0)\n        parser.add_argument('--d-emb', type=int, default=8)\n        parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False)\n        parser.add_argument('--global-att', type=str_to_bool, nargs='?', const=True, default=False)\n        parser.add_argument('--merge', type=str, default='mlp')\n        parser.add_argument('--impute-only-holes', type=str_to_bool, nargs='?', const=True, default=True)\n        return parser\n"
  },
  {
    "path": "lib/nn/models/mpgru.py",
    "content": "import torch\nfrom einops import rearrange\nfrom torch import nn\n\nfrom ..layers import MPGRUImputer, SpatialConvOrderK\nfrom ..utils.ops import reverse_tensor\nfrom ...utils.parser_utils import str_to_bool\n\n\nclass MPGRUNet(nn.Module):\n    def __init__(self,\n                 adj,\n                 d_in,\n                 d_hidden,\n                 d_ff=0,\n                 d_u=0,\n                 n_layers=1,\n                 dropout=0.,\n                 kernel_size=2,\n                 support_len=2,\n                 layer_norm=False,\n                 impute_only_holes=True):\n        super(MPGRUNet, self).__init__()\n        self.register_buffer('adj', torch.tensor(adj).float())\n        n_nodes = adj.shape[0]\n        self.gcgru = MPGRUImputer(input_size=d_in,\n                                  hidden_size=d_hidden,\n                                  ff_size=d_ff,\n                                  u_size=d_u,\n                                  n_layers=n_layers,\n                                  dropout=dropout,\n                                  kernel_size=kernel_size,\n                                  support_len=support_len,\n                                  layer_norm=layer_norm,\n                                  n_nodes=n_nodes)\n        self.impute_only_holes = impute_only_holes\n\n    def forward(self, x, mask=None, u=None, h=None):\n        # x: [batches, steps, nodes, channels] -> [batches, channels, nodes, steps]\n        x = rearrange(x, 'b s n c -> b c n s')\n        if mask is not None:\n            mask = rearrange(mask, 'b s n c -> b c n s')\n        if u is not None:\n            u = rearrange(u, 'b s n c -> b c n s')\n\n        adj = SpatialConvOrderK.compute_support(self.adj, x.device)\n        imputation, _ = self.gcgru(x, adj, mask=mask, u=u, h=h)\n\n        # In evaluation stage impute only missing values\n        if self.impute_only_holes and not self.training:\n            imputation = torch.where(mask, x, 
imputation)\n\n        # out: [batches, channels, nodes, steps] -> [batches, steps, nodes, channels]\n        imputation = rearrange(imputation, 'b c n s -> b s n c')\n\n        return imputation\n\n    @staticmethod\n    def add_model_specific_args(parser):\n        parser.add_argument('--d-hidden', type=int, default=64)\n        parser.add_argument('--d-ff', type=int, default=64)\n        parser.add_argument('--n-layers', type=int, default=1)\n        parser.add_argument('--kernel-size', type=int, default=2)\n        parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False)\n        parser.add_argument('--impute-only-holes', type=str_to_bool, nargs='?', const=True, default=True)\n        parser.add_argument('--dropout', type=float, default=0.)\n        return parser\n\n\nclass BiMPGRUNet(nn.Module):\n    def __init__(self,\n                 adj,\n                 d_in,\n                 d_hidden,\n                 d_ff=0,\n                 d_u=0,\n                 n_layers=1,\n                 dropout=0.,\n                 kernel_size=2,\n                 support_len=2,\n                 layer_norm=False,\n                 embedding_size=0,\n                 merge='mlp',\n                 impute_only_holes=True,\n                 autoencoder_mode=False):\n        super(BiMPGRUNet, self).__init__()\n        self.register_buffer('adj', torch.tensor(adj).float())\n        n_nodes = adj.shape[0]\n        self.gcgru_fwd = MPGRUImputer(input_size=d_in,\n                                      hidden_size=d_hidden,\n                                      u_size=d_u,\n                                      n_layers=n_layers,\n                                      dropout=dropout,\n                                      kernel_size=kernel_size,\n                                      support_len=support_len,\n                                      layer_norm=layer_norm,\n                                      n_nodes=n_nodes,\n                    
                  autoencoder_mode=autoencoder_mode)\n        self.gcgru_bwd = MPGRUImputer(input_size=d_in,\n                                      hidden_size=d_hidden,\n                                      u_size=d_u,\n                                      n_layers=n_layers,\n                                      dropout=dropout,\n                                      kernel_size=kernel_size,\n                                      support_len=support_len,\n                                      layer_norm=layer_norm,\n                                      n_nodes=n_nodes,\n                                      autoencoder_mode=autoencoder_mode)\n        self.impute_only_holes = impute_only_holes\n\n        if n_nodes is None:\n            embedding_size = 0\n        if embedding_size > 0:\n            self.emb = nn.Parameter(torch.empty(embedding_size, n_nodes))\n            nn.init.kaiming_normal_(self.emb, nonlinearity='relu')\n        else:\n            self.register_parameter('emb', None)\n\n        if merge == 'mlp':\n            self._impute_from_states = True\n            self.out = nn.Sequential(\n                nn.Conv2d(in_channels=2 * d_hidden + d_in + embedding_size,\n                          out_channels=d_ff, kernel_size=1),\n                nn.ReLU(),\n                nn.Conv2d(in_channels=d_ff, out_channels=d_in, kernel_size=1)\n            )\n        elif merge in ['mean', 'sum', 'min', 'max']:\n            self._impute_from_states = False\n            self.out = getattr(torch, merge)\n        else:\n            raise ValueError(\"Merge option %s not allowed.\" % merge)\n\n    def forward(self, x, mask=None, u=None, h=None):\n        # x: [batches, steps, nodes, channels] -> [batches, channels, nodes, steps]\n        x = rearrange(x, 'b s n c -> b c n s')\n        if mask is not None:\n            mask = rearrange(mask, 'b s n c -> b c n s')\n        if u is not None:\n            u = rearrange(u, 'b s n c -> b c n s')\n\n        adj = 
SpatialConvOrderK.compute_support(self.adj, x.device)\n\n        # Forward\n        fwd_pred, fwd_states = self.gcgru_fwd(x, adj, mask=mask, u=u)\n        # Backward\n        rev_x, rev_mask, rev_u = [reverse_tensor(tens, axis=-1) for tens in (x, mask, u)]\n        bwd_res = self.gcgru_bwd(rev_x, adj, mask=rev_mask, u=rev_u)\n        bwd_pred, bwd_states = [reverse_tensor(res, axis=-1) for res in bwd_res]\n\n        if self._impute_from_states:\n            inputs = [fwd_states[-1], bwd_states[-1], mask]  # take only state of last gcgru layer\n            if self.emb is not None:\n                b, *_, s = x.shape  # fwd_h: [batches, channels, nodes, steps]\n                inputs += [self.emb.view(1, *self.emb.shape, 1).expand(b, -1, -1, s)]  # stack emb for batches and steps\n            imputation = torch.cat(inputs, dim=1)\n            imputation = self.out(imputation)\n        else:\n            imputation = torch.stack([fwd_pred, bwd_pred], dim=1)\n            imputation = self.out(imputation, dim=1)\n\n        # In evaluation stage impute only missing values\n        if self.impute_only_holes and not self.training:\n            imputation = torch.where(mask, x, imputation)\n\n        # out: [batches, channels, nodes, steps] -> [batches, steps, nodes, channels]\n        imputation = rearrange(imputation, 'b c n s -> b s n c')\n\n        return imputation\n\n    @staticmethod\n    def add_model_specific_args(parser):\n        parser.add_argument('--d-hidden', type=int, default=64)\n        parser.add_argument('--d-ff', type=int, default=64)\n        parser.add_argument('--n-layers', type=int, default=1)\n        parser.add_argument('--kernel-size', type=int, default=2)\n        parser.add_argument('--d-emb', type=int, default=8)\n        parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False)\n        parser.add_argument('--merge', type=str, default='mlp')\n        parser.add_argument('--impute-only-holes', 
type=str_to_bool, nargs='?', const=True, default=True)\n        parser.add_argument('--dropout', type=float, default=0.)\n        parser.add_argument('--autoencoder-mode', type=str_to_bool, nargs='?', const=True, default=False)\n        return parser\n"
  },
  {
    "path": "lib/nn/models/rgain.py",
    "content": "import torch\nfrom torch import nn\n\nfrom .rnn_imputers import BiRNNImputer\nfrom ...utils.parser_utils import str_to_bool\n\n\nclass Generator(nn.Module):\n    def __init__(self, d_in, d_model, d_z, dropout=0., inject_noise=True):\n        super(Generator, self).__init__()\n        self.inject_noise = inject_noise\n        self.d_z = d_z if inject_noise else 0\n        self.birnn = BiRNNImputer(d_in,\n                                  d_model,\n                                  d_u=d_z,\n                                  concat_mask=True,\n                                  detach_inputs=False,\n                                  dropout=dropout,\n                                  state_init='zero')\n\n    def forward(self, x, mask):\n        if self.inject_noise:\n            z = torch.rand(x.size(0), x.size(1), self.d_z, device=x.device) * 0.1\n        else:\n            z = None\n        return self.birnn(x, mask, u=z)\n\n\nclass Discriminator(torch.nn.Module):\n    def __init__(self, d_in, d_model, dropout=0.):\n        super(Discriminator, self).__init__()\n        self.birnn = nn.GRU(2 * d_in, d_model, bidirectional=True, batch_first=True)\n        self.dropout = nn.Dropout(dropout)\n        self.read_out = nn.Linear(2 * d_model, d_in)\n\n    def forward(self, x, h):\n        x_in = torch.cat([x, h], dim=-1)\n        out, _ = self.birnn(x_in)\n        logits = self.read_out(self.dropout(out))\n        return logits\n\n\nclass RGAINNet(torch.nn.Module):\n    def __init__(self, d_in, d_model, d_z, dropout=0., inject_noise=False, k=5):\n        super(RGAINNet, self).__init__()\n        self.inject_noise = inject_noise\n        self.k = k\n        self.generator = Generator(d_in, d_model, d_z=d_z, dropout=dropout, inject_noise=inject_noise)\n        self.discriminator = Discriminator(d_in, d_model, dropout)\n\n    def forward(self, x, mask, **kwargs):\n        if not self.training and self.inject_noise:\n            res = []\n            for _ in 
range(self.k):\n                res.append(self.generator(x, mask)[0])\n            # 1-tuple: the averaged imputation stays at index 0, as in the generator output\n            return (torch.stack(res, 0).mean(0),)\n\n        return self.generator(x, mask)\n\n    @staticmethod\n    def add_model_specific_args(parser):\n        parser.add_argument('--d-in', type=int)\n        parser.add_argument('--d-model', type=int, default=None)\n        parser.add_argument('--d-z', type=int, default=8)\n        parser.add_argument('--k', type=int, default=5)\n        parser.add_argument('--inject-noise', type=str_to_bool, nargs='?', const=True, default=False)\n        parser.add_argument('--dropout', type=float, default=0.)\n        return parser\n"
  },
  {
    "path": "lib/nn/models/rnn_imputers.py",
    "content": "import torch\nfrom torch import nn\n\nfrom ..utils.ops import reverse_tensor\n\n\nclass RNNImputer(nn.Module):\n    \"\"\"Fill the blanks with a 1-step-ahead GRU predictor.\"\"\"\n\n    def __init__(self, d_in, d_model, concat_mask=True, detach_inputs=False, state_init='zero', d_u=0):\n        super(RNNImputer, self).__init__()\n        self.concat_mask = concat_mask\n        self.detach_inputs = detach_inputs\n        self.state_init = state_init\n        self.d_model = d_model\n        self.input_dim = d_in + d_u if not concat_mask else 2 * d_in + d_u\n        self.rnn_cell = nn.GRUCell(self.input_dim, d_model)\n        self.read_out = nn.Linear(d_model, d_in)\n\n    def init_hidden_state(self, x):\n        if self.state_init == 'zero':\n            return torch.zeros((x.size(0), self.d_model), device=x.device, dtype=x.dtype)\n        if self.state_init == 'noise':\n            return torch.randn(x.size(0), self.d_model, device=x.device, dtype=x.dtype)\n\n    def _preprocess_input(self, x, x_hat, m, u):\n        if self.detach_inputs:\n            x_p = torch.where(m, x, x_hat.detach())\n        else:\n            x_p = torch.where(m, x, x_hat)\n\n        if u is not None:\n            x_p = torch.cat([x_p, u], -1)\n        if self.concat_mask:\n            x_p = torch.cat([x_p, m], -1)\n        return x_p\n\n    def forward(self, x, mask, u=None, return_hidden=False):\n        # x: [batches, steps, features]\n        steps = x.size(1)\n        # ensure masked values are not visible\n        x = torch.where(mask, x, torch.zeros_like(x))\n\n        h = self.init_hidden_state(x)\n        x_hat = self.read_out(h)\n        hs = [h]\n        preds = [x_hat]\n        for s in range(steps - 1):\n            u_t = None if u is None else u[:, s]\n            x_t = self._preprocess_input(x[:, s], x_hat, mask[:, s], u_t)\n            h = self.rnn_cell(x_t, h)\n            x_hat = self.read_out(h)\n            hs.append(h)\n            preds.append(x_hat)\n\n 
       x_hat = torch.stack(preds, 1)\n        h = torch.stack(hs, 1)\n        if return_hidden:\n            return x_hat, h\n        return x_hat\n\n    @staticmethod\n    def add_model_specific_args(parser):\n        parser.add_argument('--d-in', type=int)\n        parser.add_argument('--d-model', type=int, default=None)\n        return parser\n\n\nclass BiRNNImputer(nn.Module):\n    \"\"\"Fill the blanks with a 1-step-ahead GRU predictor.\"\"\"\n\n    def __init__(self, d_in, d_model, dropout=0., concat_mask=True, detach_inputs=False, state_init='zero', d_u=0):\n        super(BiRNNImputer, self).__init__()\n        self.d_model = d_model\n        self.fwd_rnn = RNNImputer(d_in, d_model, concat_mask, detach_inputs=detach_inputs, state_init=state_init,\n                                  d_u=d_u)\n        self.bwd_rnn = RNNImputer(d_in, d_model, concat_mask, detach_inputs=detach_inputs, state_init=state_init,\n                                  d_u=d_u)\n        self.dropout = nn.Dropout(dropout)\n        self.read_out = nn.Linear(2 * d_model, d_in)\n\n    def forward(self, x, mask, u=None, return_hidden=False):\n        # x: [batches, steps, features]\n        x_hat_fwd, h_fwd = self.fwd_rnn(x, mask, u=u, return_hidden=True)\n        x_hat_bwd, h_bwd = self.bwd_rnn(reverse_tensor(x, 1),\n                                        reverse_tensor(mask, 1),\n                                        u=reverse_tensor(u, 1) if u is not None else None,\n                                        return_hidden=True)\n        x_hat_bwd = reverse_tensor(x_hat_bwd, 1)\n        h_bwd = reverse_tensor(h_bwd, 1)\n        h = self.dropout(torch.cat([h_fwd, h_bwd], -1))\n        x_hat = self.read_out(h)\n        if return_hidden:\n            return (x_hat, x_hat_fwd, x_hat_bwd), h\n        return x_hat, x_hat_fwd, x_hat_bwd\n\n    @staticmethod\n    def add_model_specific_args(parser):\n        parser.add_argument('--d-in', type=int)\n        parser.add_argument('--d-model', type=int, 
default=None)\n        parser.add_argument('--dropout', type=float, default=0.)\n        return parser\n"
  },
  {
    "path": "lib/nn/models/var.py",
    "content": "import torch\nfrom einops import rearrange\nfrom torch import nn\n\nfrom lib import epsilon\n\n\nclass VAR(nn.Module):\n    def __init__(self, order, d_in, d_out=None, steps_ahead=1, bias=True):\n        super(VAR, self).__init__()\n        self.order = order\n        self.d_in = d_in\n        self.d_out = d_out if d_out is not None else d_in\n        self.steps_ahead = steps_ahead\n        self.lin = nn.Linear(order * d_in, steps_ahead * self.d_out, bias=bias)\n\n    def forward(self, x):\n        # x: [batches, steps, features]\n        x = rearrange(x, 'b s f -> b (s f)')\n        out = self.lin(x)\n        out = rearrange(out, 'b (s f) -> b s f', s=self.steps_ahead, f=self.d_out)\n        return out\n\n    @staticmethod\n    def add_model_specific_args(parser):\n        parser.add_argument('--order', type=int)\n        parser.add_argument('--d-in', type=int)\n        parser.add_argument('--d-out', type=int, default=None)\n        parser.add_argument('--steps-ahead', type=int, default=1)\n        return parser\n\n\nclass VARImputer(nn.Module):\n    \"\"\"Fill the blanks with a 1-step-ahead VAR predictor.\"\"\"\n\n    def __init__(self, order, d_in, padding='mean'):\n        super(VARImputer, self).__init__()\n        assert padding in ['mean', 'zero']\n        self.order = order\n        self.padding = padding\n        self.predictor = VAR(order, d_in, d_out=d_in, steps_ahead=1)\n\n    def forward(self, x, mask=None):\n        # x: [batches, steps, features]\n        batch_size, steps, n_feats = x.shape\n        if mask is None:\n            mask = torch.ones_like(x, dtype=torch.uint8)\n        x = x * mask\n        # pad input sequence to start filling from first step\n        if self.padding == 'mean':\n            mean = torch.sum(x, 1) / (torch.sum(mask, 1) + epsilon)\n            pad = torch.repeat_interleave(mean.unsqueeze(1), self.order, 1)\n        elif self.padding == 'zero':\n            pad = torch.zeros((batch_size, self.order, 
n_feats)).to(x.device)\n        x = torch.cat([pad, x], 1)\n        # x: [batch, order + steps, features]\n        x = [x[:, i] for i in range(x.shape[1])]\n        for s in range(steps):\n            x_hat = self.predictor(torch.stack(x[s:s + self.order], 1))\n            x_hat = x_hat[:, 0]\n            x[s + self.order] = torch.where(mask[:, s], x[s + self.order], x_hat)\n        x = torch.stack(x[self.order:], 1)  # remove padding\n        return x\n\n    @staticmethod\n    def add_model_specific_args(parser):\n        parser.add_argument('--order', type=int)\n        parser.add_argument('--d-in', type=int)\n        parser.add_argument(\"--padding\", type=str, default='mean')\n        return parser\n"
  },
  {
    "path": "lib/nn/utils/__init__.py",
    "content": ""
  },
  {
    "path": "lib/nn/utils/metric_base.py",
    "content": "from functools import partial\n\nimport torch\nfrom pytorch_lightning.metrics import Metric\nfrom torchmetrics.utilities.checks import _check_same_shape\n\n\nclass MaskedMetric(Metric):\n    def __init__(self,\n                 metric_fn,\n                 mask_nans=False,\n                 mask_inf=False,\n                 compute_on_step=True,\n                 dist_sync_on_step=False,\n                 process_group=None,\n                 dist_sync_fn=None,\n                 metric_kwargs=None,\n                 at=None):\n        super(MaskedMetric, self).__init__(compute_on_step=compute_on_step,\n                                           dist_sync_on_step=dist_sync_on_step,\n                                           process_group=process_group,\n                                           dist_sync_fn=dist_sync_fn)\n\n        if metric_kwargs is None:\n            metric_kwargs = dict()\n        self.metric_fn = partial(metric_fn, **metric_kwargs)\n        self.mask_nans = mask_nans\n        self.mask_inf = mask_inf\n        if at is None:\n            self.at = slice(None)\n        else:\n            self.at = slice(at, at + 1)\n        self.add_state('value', dist_reduce_fx='sum', default=torch.tensor(0.).float())\n        self.add_state('numel', dist_reduce_fx='sum', default=torch.tensor(0))\n\n    def _check_mask(self, mask, val):\n        if mask is None:\n            mask = torch.ones_like(val).byte()\n        else:\n            _check_same_shape(mask, val)\n        if self.mask_nans:\n            mask = mask * ~torch.isnan(val)\n        if self.mask_inf:\n            mask = mask * ~torch.isinf(val)\n        return mask\n\n    def _compute_masked(self, y_hat, y, mask):\n        _check_same_shape(y_hat, y)\n        val = self.metric_fn(y_hat, y)\n        mask = self._check_mask(mask, val)\n        val = torch.where(mask, val, torch.tensor(0., device=val.device).float())\n        return val.sum(), mask.sum()\n\n    def _compute_std(self, 
y_hat, y):\n        _check_same_shape(y_hat, y)\n        val = self.metric_fn(y_hat, y)\n        return val.sum(), val.numel()\n\n    def is_masked(self, mask):\n        return self.mask_inf or self.mask_nans or (mask is not None)\n\n    def update(self, y_hat, y, mask=None):\n        y_hat = y_hat[:, self.at]\n        y = y[:, self.at]\n        if mask is not None:\n            mask = mask[:, self.at]\n        if self.is_masked(mask):\n            val, numel = self._compute_masked(y_hat, y, mask)\n        else:\n            val, numel = self._compute_std(y_hat, y)\n        self.value += val\n        self.numel += numel\n\n    def compute(self):\n        if self.numel > 0:\n            return self.value / self.numel\n        return self.value\n"
  },
  {
    "path": "lib/nn/utils/metrics.py",
    "content": "from .metric_base import MaskedMetric\nfrom .ops import mape\nfrom torch.nn import functional as F\nimport torch\n\nfrom torchmetrics.utilities.checks import _check_same_shape\n\nfrom ... import epsilon\n\n\nclass MaskedMAE(MaskedMetric):\n    def __init__(self,\n                 mask_nans=False,\n                 mask_inf=False,\n                 compute_on_step=True,\n                 dist_sync_on_step=False,\n                 process_group=None,\n                 dist_sync_fn=None,\n                 at=None):\n        super(MaskedMAE, self).__init__(metric_fn=F.l1_loss,\n                                        mask_nans=mask_nans,\n                                        mask_inf=mask_inf,\n                                        compute_on_step=compute_on_step,\n                                        dist_sync_on_step=dist_sync_on_step,\n                                        process_group=process_group,\n                                        dist_sync_fn=dist_sync_fn,\n                                        metric_kwargs={'reduction': 'none'},\n                                        at=at)\n\n\nclass MaskedMAPE(MaskedMetric):\n    def __init__(self,\n                 mask_nans=False,\n                 compute_on_step=True,\n                 dist_sync_on_step=False,\n                 process_group=None,\n                 dist_sync_fn=None,\n                 at=None):\n        super(MaskedMAPE, self).__init__(metric_fn=mape,\n                                         mask_nans=mask_nans,\n                                         mask_inf=True,\n                                         compute_on_step=compute_on_step,\n                                         dist_sync_on_step=dist_sync_on_step,\n                                         process_group=process_group,\n                                         dist_sync_fn=dist_sync_fn,\n                                         at=at)\n\n\nclass MaskedMSE(MaskedMetric):\n    def __init__(self,\n  
               mask_nans=False,\n                 compute_on_step=True,\n                 dist_sync_on_step=False,\n                 process_group=None,\n                 dist_sync_fn=None,\n                 at=None):\n        super(MaskedMSE, self).__init__(metric_fn=F.mse_loss,\n                                        mask_nans=mask_nans,\n                                        mask_inf=True,\n                                        compute_on_step=compute_on_step,\n                                        dist_sync_on_step=dist_sync_on_step,\n                                        process_group=process_group,\n                                        dist_sync_fn=dist_sync_fn,\n                                        metric_kwargs={'reduction': 'none'},\n                                        at=at)\n\n\nclass MaskedMRE(MaskedMetric):\n    def __init__(self,\n                 mask_nans=False,\n                 mask_inf=False,\n                 compute_on_step=True,\n                 dist_sync_on_step=False,\n                 process_group=None,\n                 dist_sync_fn=None,\n                 at=None):\n        super(MaskedMRE, self).__init__(metric_fn=F.l1_loss,\n                                        mask_nans=mask_nans,\n                                        mask_inf=mask_inf,\n                                        compute_on_step=compute_on_step,\n                                        dist_sync_on_step=dist_sync_on_step,\n                                        process_group=process_group,\n                                        dist_sync_fn=dist_sync_fn,\n                                        metric_kwargs={'reduction': 'none'},\n                                        at=at)\n        self.add_state('tot', dist_reduce_fx='sum', default=torch.tensor(0., dtype=torch.float))\n\n    def _compute_masked(self, y_hat, y, mask):\n        _check_same_shape(y_hat, y)\n        val = self.metric_fn(y_hat, y)\n        mask = self._check_mask(mask, 
val)\n        val = torch.where(mask, val, torch.tensor(0., device=y.device, dtype=torch.float))\n        y_masked = torch.where(mask, y, torch.tensor(0., device=y.device, dtype=torch.float))\n        return val.sum(), mask.sum(), y_masked.sum()\n\n    def _compute_std(self, y_hat, y):\n        _check_same_shape(y_hat, y)\n        val = self.metric_fn(y_hat, y)\n        return val.sum(), val.numel(), y.sum()\n\n    def compute(self):\n        if self.tot > epsilon:\n            return self.value / self.tot\n        return self.value\n\n    def update(self, y_hat, y, mask=None):\n        y_hat = y_hat[:, self.at]\n        y = y[:, self.at]\n        if mask is not None:\n            mask = mask[:, self.at]\n        if self.is_masked(mask):\n            val, numel, tot = self._compute_masked(y_hat, y, mask)\n        else:\n            val, numel, tot = self._compute_std(y_hat, y)\n        self.value += val\n        self.numel += numel\n        self.tot += tot\n\n\n"
  },
  {
    "path": "lib/nn/utils/ops.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom einops import reduce\nfrom torch.autograd import Variable\n\nfrom ... import epsilon\n\n\ndef mae(y_hat, y, reduction='none'):\n    return F.l1_loss(y_hat, y, reduction=reduction)\n\n\ndef mape(y_hat, y):\n    return torch.abs((y_hat - y) / y)\n\n\ndef wape_loss(y_hat, y):\n    l = torch.abs(y_hat - y)\n    return l.sum() / (y.sum() + epsilon)\n\n\ndef smape_loss(y_hat, y):\n    c = torch.abs(y) > epsilon\n    l_minus = torch.abs(y_hat - y)\n    l_plus = torch.abs(y_hat + y) + epsilon\n    l = 2 * l_minus / l_plus * c.float()\n    return l.sum() / c.sum()\n\n\ndef peak_prediction_loss(y_hat, y, reduction='none'):\n    y_max = reduce(y, 'b s n 1 -> b 1 n 1', 'max')\n    y_min = reduce(y, 'b s n 1 -> b 1 n 1', 'min')\n    target = torch.cat([y_max, y_min], dim=1)\n    return F.mse_loss(y_hat, target, reduction=reduction)\n\n\ndef wrap_loss_fn(base_loss):\n    def loss_fn(y_hat, y_true, mask=None):\n        scaling = 1.\n        if mask is not None:\n            try:\n                loss = base_loss(y_hat, y_true, reduction='none')\n            except TypeError:\n                loss = base_loss(y_hat, y_true)\n            loss = loss * mask\n            loss = loss.sum() / (mask.sum() + epsilon)\n            # scaling = mask.sum() / torch.numel(mask)\n        else:\n            loss = base_loss(y_hat, y_true).mean()\n        return scaling * loss\n\n    return loss_fn\n\n\ndef rbf_sim(x, gamma, device='cpu'):\n    n = x.size()[0]\n    a = torch.exp(-gamma * F.pdist(x, 2) ** 2)\n    row_idx, col_idx = torch.triu_indices(n, n, 1)\n    A = 0.5 * torch.eye(n, n).to(device)\n    A[row_idx, col_idx] = a\n    return A + A.T\n\n\ndef reverse_tensor(tensor=None, axis=-1):\n    if tensor is None:\n        return None\n    if tensor.dim() <= 1:\n        return tensor\n    indices = range(tensor.size()[axis])[::-1]\n    indices = Variable(torch.LongTensor(indices), requires_grad=False).to(tensor.device)\n    
return tensor.index_select(axis, indices)\n"
  },
  {
    "path": "lib/utils/__init__.py",
    "content": "from .utils import *\n"
  },
  {
    "path": "lib/utils/numpy_metrics.py",
    "content": "import numpy as np\n\n\ndef mae(y_hat, y):\n    return np.abs(y_hat - y).mean()\n\n\ndef nmae(y_hat, y):\n    delta = np.max(y) - np.min(y) + 1e-8\n    return mae(y_hat, y) * 100 / delta\n\n\ndef mape(y_hat, y):\n    return 100 * np.abs((y_hat - y) / (y + 1e-8)).mean()\n\n\ndef mse(y_hat, y):\n    return np.square(y_hat - y).mean()\n\n\ndef rmse(y_hat, y):\n    return np.sqrt(mse(y_hat, y))\n\n\ndef nrmse(y_hat, y):\n    delta = np.max(y) - np.min(y) + 1e-8\n    return rmse(y_hat, y) * 100 / delta\n\n\ndef nrmse_2(y_hat, y):\n    nrmse_ = np.sqrt(np.square(y_hat - y).sum() / np.square(y).sum())\n    return nrmse_ * 100\n\n\ndef r2(y_hat, y):\n    return 1. - np.square(y_hat - y).sum() / (np.square(y.mean(0) - y).sum())\n\n\ndef masked_mae(y_hat, y, mask):\n    err = np.abs(y_hat - y) * mask\n    return err.sum() / mask.sum()\n\n\ndef masked_mape(y_hat, y, mask):\n    err = np.abs((y_hat - y) / (y + 1e-8)) * mask\n    return err.sum() / mask.sum()\n\n\ndef masked_mse(y_hat, y, mask):\n    err = np.square(y_hat - y) * mask\n    return err.sum() / mask.sum()\n\n\ndef masked_rmse(y_hat, y, mask):\n    err = np.square(y_hat - y) * mask\n    return np.sqrt(err.sum() / mask.sum())\n\n\ndef masked_mre(y_hat, y, mask):\n    err = np.abs(y_hat - y) * mask\n    return err.sum() / ((y * mask).sum() + 1e-8)\n"
  },
  {
    "path": "lib/utils/parser_utils.py",
    "content": "import inspect\nfrom argparse import Namespace, ArgumentParser\nfrom typing import Union\n\n\ndef str_to_bool(value):\n    if isinstance(value, bool):\n        return value\n    if value.lower() in {'false', 'f', '0', 'no', 'n', 'off'}:\n        return False\n    elif value.lower() in {'true', 't', '1', 'yes', 'y', 'on'}:\n        return True\n    raise ValueError(f'{value} is not a valid boolean value')\n\n\ndef config_dict_from_args(args):\n    \"\"\"\n    Extract a dictionary with the experiment configuration from arguments (necessary to filter TestTube arguments)\n\n    :param args: TTNamespace\n    :return: hyparams dict\n    \"\"\"\n    keys_to_remove = {'hpc_exp_number', 'trials', 'optimize_parallel', 'optimize_parallel_gpu',\n                      'optimize_parallel_cpu', 'generate_trials', 'optimize_trials_parallel_gpu'}\n    hparams = {key: v for key, v in args.__dict__.items() if key not in keys_to_remove}\n    return hparams\n\n\ndef update_from_config(args: Namespace, config: dict):\n    assert set(config.keys()) <= set(vars(args)), f'{set(config.keys()).difference(vars(args))} not in args.'\n    args.__dict__.update(config)\n    return args\n\n\ndef parse_by_group(parser):\n    \"\"\"\n    Create a nested namespace using the groups defined in the argument parser.\n    Adapted from https://stackoverflow.com/a/56631542/6524027\n\n    :param args: arguments\n    :param parser: the parser\n    :return:\n    \"\"\"\n    assert isinstance(parser, ArgumentParser)\n    args = parser.parse_args()\n\n    # the first two argument groups are 'positional_arguments' and 'optional_arguments'\n    pos_group, optional_group = parser._action_groups[0], parser._action_groups[1]\n    args_dict = args._get_kwargs()\n    pos_optional_arg_names = [arg.dest for arg in pos_group._group_actions] + [arg.dest for arg in\n                                                                               optional_group._group_actions]\n    pos_optional_args = {name: 
value for name, value in args_dict if name in pos_optional_arg_names}\n    other_group_args = dict()\n\n    # If there are additional argument groups, add them as nested namespaces\n    if len(parser._action_groups) > 2:\n        for group in parser._action_groups[2:]:\n            group_arg_names = [arg.dest for arg in group._group_actions]\n            other_group_args[group.title] = Namespace(\n                **{name: value for name, value in args_dict if name in group_arg_names})\n\n    # combine the positional/optional args and the group args\n    combined_args = pos_optional_args\n    combined_args.update(other_group_args)\n    return Namespace(flat=args, **combined_args)\n\n\ndef filter_args(args: Union[Namespace, dict], target_cls, return_dict=False):\n    argspec = inspect.getfullargspec(target_cls.__init__)\n    target_args = argspec.args\n    if isinstance(args, Namespace):\n        args = vars(args)\n    filtered_args = {k: args[k] for k in target_args if k in args}\n    if return_dict:\n        return filtered_args\n    return Namespace(**filtered_args)\n\n\ndef filter_function_args(args: Union[Namespace, dict], function, return_dict=False):\n    argspec = inspect.getfullargspec(function)\n    target_args = argspec.args\n    if isinstance(args, Namespace):\n        args = vars(args)\n    filtered_args = {k: args[k] for k in target_args if k in args}\n    if return_dict:\n        return filtered_args\n    return Namespace(**filtered_args)\n"
  },
  {
    "path": "lib/utils/utils.py",
    "content": "import numpy as np\nimport pandas as pd\n\nfrom sklearn.metrics.pairwise import haversine_distances\n\n\ndef sample_mask(shape, p=0.002, p_noise=0., max_seq=1, min_seq=1, rng=None):\n    if rng is None:\n        rand = np.random.random\n        randint = np.random.randint\n    else:\n        rand = rng.random\n        randint = rng.integers\n    mask = rand(shape) < p\n    for col in range(mask.shape[1]):\n        idxs = np.flatnonzero(mask[:, col])\n        if not len(idxs):\n            continue\n        fault_len = min_seq\n        if max_seq > min_seq:\n            fault_len = fault_len + int(randint(max_seq - min_seq))\n        idxs_ext = np.concatenate([np.arange(i, i + fault_len) for i in idxs])\n        idxs = np.unique(idxs_ext)\n        idxs = np.clip(idxs, 0, shape[0] - 1)\n        mask[idxs, col] = True\n    mask = mask | (rand(mask.shape) < p_noise)\n    return mask.astype('uint8')\n\n\ndef compute_mean(x, index=None):\n    \"\"\"Compute the mean values for each datetime. The mean is first computed hourly over the week of the year.\n    Further NaN values are computed using hourly mean over the same month through the years. If other NaN are present,\n    they are removed using the mean of the sole hours. 
This assumes that the dataset contains at least one non-NaN entry for\n    the same hour as each NaN datetime.\"\"\"\n    if isinstance(x, np.ndarray) and index is not None:\n        shape = x.shape\n        x = x.reshape((shape[0], -1))\n        df_mean = pd.DataFrame(x, index=index)\n    else:\n        df_mean = x.copy()\n    cond0 = [df_mean.index.year, df_mean.index.isocalendar().week, df_mean.index.hour]\n    cond1 = [df_mean.index.year, df_mean.index.month, df_mean.index.hour]\n    conditions = [cond0, cond1, cond1[1:], cond1[2:]]\n    while df_mean.isna().values.sum() and len(conditions):\n        nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)\n        df_mean = df_mean.fillna(nan_mean)\n        conditions = conditions[1:]\n    if df_mean.isna().values.sum():\n        df_mean = df_mean.fillna(method='ffill')\n        df_mean = df_mean.fillna(method='bfill')\n    if isinstance(x, np.ndarray):\n        df_mean = df_mean.values.reshape(shape)\n    return df_mean\n\n\ndef geographical_distance(x=None, to_rad=True):\n    \"\"\"\n    Compute the as-the-crow-flies distance between every pair of samples in `x`. The first dimension of each point is\n    assumed to be the latitude, the second is the longitude. The input is assumed to be in degrees. If this is not the\n    case, `to_rad` must be set to False. 
The dimension of the data must be 2.\n\n    Parameters\n    ----------\n    x : pd.DataFrame or np.ndarray\n        array_like structure of shape (n_samples_2, 2).\n    to_rad : bool\n        whether to convert inputs to radians (provided that they are in degrees).\n\n    Returns\n    -------\n    distances :\n        The distance between the points in kilometers.\n    \"\"\"\n    _AVG_EARTH_RADIUS_KM = 6371.0088\n\n    # Extract values of `x` if it is a DataFrame, else assume it is a 2-dim array of lat-lon pairs\n    latlon_pairs = x.values if isinstance(x, pd.DataFrame) else x\n\n    # If the input values are in degrees, convert them to radians\n    if to_rad:\n        latlon_pairs = np.vectorize(np.radians)(latlon_pairs)\n\n    distances = haversine_distances(latlon_pairs) * _AVG_EARTH_RADIUS_KM\n\n    # Cast response\n    if isinstance(x, pd.DataFrame):\n        res = pd.DataFrame(distances, x.index, x.index)\n    else:\n        res = distances\n\n    return res\n\n\ndef infer_mask(df, infer_from='next'):\n    \"\"\"Infer evaluation mask from DataFrame. 
In the evaluation mask a value is 1 if it is present in the DataFrame and\n    absent in the `infer_from` month.\n\n    @param pd.DataFrame df: the DataFrame.\n    @param str infer_from: denotes from which month the evaluation value must be inferred.\n    Can be either `previous` or `next`.\n    @return: pd.DataFrame eval_mask: the evaluation mask for the DataFrame\n    \"\"\"\n    mask = (~df.isna()).astype('uint8')\n    eval_mask = pd.DataFrame(index=mask.index, columns=mask.columns, data=0).astype('uint8')\n    if infer_from == 'previous':\n        offset = -1\n    elif infer_from == 'next':\n        offset = 1\n    else:\n        raise ValueError('infer_from can only be one of %s' % ['previous', 'next'])\n    months = sorted(set(zip(mask.index.year, mask.index.month)))\n    length = len(months)\n    for i in range(length):\n        j = (i + offset) % length\n        year_i, month_i = months[i]\n        year_j, month_j = months[j]\n        mask_j = mask[(mask.index.year == year_j) & (mask.index.month == month_j)]\n        mask_i = mask_j.shift(1, pd.DateOffset(months=12 * (year_i - year_j) + (month_i - month_j)))\n        mask_i = mask_i[~mask_i.index.duplicated(keep='first')]\n        mask_i = mask_i[np.in1d(mask_i.index, mask.index)]\n        eval_mask.loc[mask_i.index] = ~mask_i.loc[mask_i.index] & mask.loc[mask_i.index]\n    return eval_mask\n\n\ndef prediction_dataframe(y, index, columns=None, aggregate_by='mean'):\n    \"\"\"Aggregate batched predictions in a single DataFrame.\n\n    @param (list or np.ndarray) y: the list of predictions.\n    @param (list or np.ndarray) index: the list of time indexes coupled with the predictions.\n    @param (list or pd.Index) columns: the columns of the returned DataFrame.\n    @param (str or list) aggregate_by: how to aggregate the predictions in case there are more than one for a step.\n    - `mean`: take the mean of the predictions\n    - `central`: take the prediction at the central position, assuming that the 
predictions are ordered chronologically\n    - `smooth_central`: average the predictions weighted by a gaussian signal with std=1\n    - `last`: take the last prediction\n    @return: pd.DataFrame df: the aggregated predictions (a list of DataFrames if `aggregate_by` is a list)\n    \"\"\"\n    dfs = [pd.DataFrame(data=data.reshape(data.shape[:2]), index=idx, columns=columns) for data, idx in zip(y, index)]\n    df = pd.concat(dfs)\n    preds_by_step = df.groupby(df.index)\n    # aggregate according to the passed methods\n    aggr_methods = ensure_list(aggregate_by)\n    dfs = []\n    for aggr_by in aggr_methods:\n        if aggr_by == 'mean':\n            dfs.append(preds_by_step.mean())\n        elif aggr_by == 'central':\n            dfs.append(preds_by_step.aggregate(lambda x: x[int(len(x) // 2)]))\n        elif aggr_by == 'smooth_central':\n            from scipy.signal import gaussian\n            dfs.append(preds_by_step.aggregate(lambda x: np.average(x, weights=gaussian(len(x), 1))))\n        elif aggr_by == 'last':\n            dfs.append(preds_by_step.aggregate(lambda x: x[0]))  # first imputation has missing value in last position\n        else:\n            raise ValueError('aggregate_by can only be one of %s' % ['mean', 'central', 'smooth_central', 'last'])\n    if isinstance(aggregate_by, str):\n        return dfs[0]\n    return dfs\n\n\ndef ensure_list(obj):\n    if isinstance(obj, (list, tuple)):\n        return list(obj)\n    else:\n        return [obj]\n\n\ndef missing_val_lens(mask):\n    m = np.concatenate([np.zeros((1, mask.shape[1])),\n                        (~mask.astype('bool')).astype('int'),\n                        np.zeros((1, mask.shape[1]))])\n    mdiff = np.diff(m, axis=0)\n    lens = []\n    for c in range(m.shape[1]):\n        mj, = mdiff[:, c].nonzero()\n        diff = np.diff(mj)[::2]\n        lens.extend(list(diff))\n    return lens\n\n\ndef disjoint_months(dataset, months=None, synch_mode='window'):\n    idxs = np.arange(len(dataset))\n    months = ensure_list(months)\n    # 
divide indices according to window or horizon\n    if synch_mode == 'window':\n        start, end = 0, dataset.window - 1\n    elif synch_mode == 'horizon':\n        start, end = dataset.horizon_offset, dataset.horizon_offset + dataset.horizon - 1\n    else:\n        raise ValueError('synch_mode can only be one of %s' % ['window', 'horizon'])\n    # after idxs\n    start_in_months = np.in1d(dataset.index[dataset._indices + start].month, months)\n    end_in_months = np.in1d(dataset.index[dataset._indices + end].month, months)\n    idxs_in_months = start_in_months & end_in_months\n    after_idxs = idxs[idxs_in_months]\n    # previous idxs\n    months = np.setdiff1d(np.arange(1, 13), months)\n    start_in_months = np.in1d(dataset.index[dataset._indices + start].month, months)\n    end_in_months = np.in1d(dataset.index[dataset._indices + end].month, months)\n    idxs_in_months = start_in_months & end_in_months\n    prev_idxs = idxs[idxs_in_months]\n    return prev_idxs, after_idxs\n\n\ndef thresholded_gaussian_kernel(x, theta=None, threshold=None, threshold_on_input=False):\n    if theta is None:\n        theta = np.std(x)\n    weights = np.exp(-np.square(x / theta))\n    if threshold is not None:\n        mask = x > threshold if threshold_on_input else weights < threshold\n        weights[mask] = 0.\n    return weights\n"
  },
  {
    "path": "requirements.txt",
    "content": "einops\nfancyimpute==0.6\nh5py\nopenpyxl\nnumpy\npandas\npytorch-lightning==1.4\npyyaml\nscikit-learn\nscipy\ntables\ntensorboard\ntensorflow==2.5.0\ntensorflow-gpu==2.4.0\ntorch==1.8\ntorchvision\ntorchaudio\ntorchmetrics==0.5\n"
  },
  {
    "path": "scripts/run_baselines.py",
    "content": "from argparse import ArgumentParser\n\nimport numpy as np\nfrom fancyimpute import MatrixFactorization, IterativeImputer\nfrom sklearn.neighbors import kneighbors_graph\n\nfrom lib import datasets\nfrom lib.utils import numpy_metrics\nfrom lib.utils.parser_utils import str_to_bool\n\nmetrics = {\n    'mae': numpy_metrics.masked_mae,\n    'mse': numpy_metrics.masked_mse,\n    'mre': numpy_metrics.masked_mre,\n    'mape': numpy_metrics.masked_mape\n}\n\n\ndef parse_args():\n    parser = ArgumentParser()\n    # experiment setting\n    parser.add_argument('--datasets', nargs='+', type=str, default=['all'])\n    parser.add_argument('--imputers', nargs='+', type=str, default=['all'])\n    parser.add_argument('--n-runs', type=int, default=5)\n    parser.add_argument('--in-sample', type=str_to_bool, nargs='?', const=True, default=True)\n    # SpatialKNNImputer params\n    parser.add_argument('--k', type=int, default=10)\n    # MFImputer params\n    parser.add_argument('--rank', type=int, default=10)\n    # MICEImputer params\n    parser.add_argument('--mice-iterations', type=int, default=100)\n    parser.add_argument('--mice-n-features', type=int, default=None)\n    args = parser.parse_args()\n    # parse dataset\n    if args.datasets[0] == 'all':\n        args.datasets = ['air36', 'air', 'bay', 'irish', 'la', 'bay_noise', 'irish_noise', 'la_noise']\n    # parse imputers\n    if args.imputers[0] == 'all':\n        args.imputers = ['mean', 'knn', 'mf', 'mice']\n    if not args.in_sample:\n        args.imputers = [name for name in args.imputers if name in ['mean', 'mice']]\n    return args\n\n\nclass Imputer:\n    short_name: str\n\n    def __init__(self, method=None, is_deterministic=True, in_sample=True):\n        self.name = self.__class__.__name__\n        self.method = method\n        self.is_deterministic = is_deterministic\n        self.in_sample = in_sample\n\n    def fit(self, x, mask):\n        if not self.in_sample:\n            x_hat = 
np.where(mask, x, np.nan)\n            return self.method.fit(x_hat)\n\n    def predict(self, x, mask):\n        x_hat = np.where(mask, x, np.nan)\n        if self.in_sample:\n            return self.method.fit_transform(x_hat)\n        else:\n            return self.method.transform(x_hat)\n\n    def params(self):\n        return dict()\n\n\nclass SpatialKNNImputer(Imputer):\n    short_name = 'knn'\n\n    def __init__(self, adj, k=20):\n        super(SpatialKNNImputer, self).__init__()\n        self.k = k\n        # normalize sim between [0, 1]\n        sim = (adj + adj.min()) / (adj.max() + adj.min())\n        knns = kneighbors_graph(1 - sim,\n                                n_neighbors=self.k,\n                                include_self=False,\n                                metric='precomputed').toarray()\n        self.knns = knns\n\n    def fit(self, x, mask):\n        pass\n\n    def predict(self, x, mask):\n        x = np.where(mask, x, 0)\n        with np.errstate(divide='ignore', invalid='ignore'):\n            y_hat = (x @ self.knns.T) / (mask @ self.knns.T)\n        y_hat[~np.isfinite(y_hat)] = x.mean()\n        return np.where(mask, x, y_hat)\n\n    def params(self):\n        return dict(k=self.k)\n\n\nclass MeanImputer(Imputer):\n    short_name = 'mean'\n\n    def fit(self, x, mask):\n        d = np.where(mask, x, np.nan)\n        self.means = np.nanmean(d, axis=0, keepdims=True)\n\n    def predict(self, x, mask):\n        if self.in_sample:\n            d = np.where(mask, x, np.nan)\n            means = np.nanmean(d, axis=0, keepdims=True)\n        else:\n            means = self.means\n        return np.where(mask, x, means)\n\n\nclass MatrixFactorizationImputer(Imputer):\n    short_name = 'mf'\n\n    def __init__(self, rank=10, loss='mae', verbose=0):\n        method = MatrixFactorization(rank=rank, loss=loss, verbose=verbose)\n        super(MatrixFactorizationImputer, self).__init__(method, is_deterministic=False, in_sample=True)\n\n    def 
params(self):\n        return dict(rank=self.method.rank)\n\n\nclass MICEImputer(Imputer):\n    short_name = 'mice'\n\n    def __init__(self, max_iter=100, n_nearest_features=None, in_sample=True, verbose=False):\n        method = IterativeImputer(max_iter=max_iter, n_nearest_features=n_nearest_features, verbose=verbose)\n        is_deterministic = n_nearest_features is None\n        super(MICEImputer, self).__init__(method, is_deterministic=is_deterministic, in_sample=in_sample)\n\n    def params(self):\n        return dict(max_iter=self.method.max_iter, k=self.method.n_nearest_features or -1)\n\n\ndef get_dataset(dataset_name):\n    if dataset_name[:3] == 'air':\n        dataset = datasets.AirQuality(impute_nans=True, small=dataset_name[3:] == '36')\n    elif dataset_name == 'bay':\n        dataset = datasets.MissingValuesPemsBay()\n    elif dataset_name == 'la':\n        dataset = datasets.MissingValuesMetrLA()\n    elif dataset_name == 'la_noise':\n        dataset = datasets.MissingValuesMetrLA(p_fault=0., p_noise=0.25)\n    elif dataset_name == 'bay_noise':\n        dataset = datasets.MissingValuesPemsBay(p_fault=0., p_noise=0.25)\n    else:\n        raise ValueError(f\"Dataset {dataset_name} not available in this setting.\")\n\n    # split in train/test\n    if isinstance(dataset, datasets.AirQuality):\n        test_slice = np.in1d(dataset.df.index.month, dataset.test_months)\n        train_slice = ~test_slice\n    else:\n        train_slice = np.zeros(len(dataset)).astype(bool)\n        train_slice[:-int(0.2 * len(dataset))] = True\n\n    # integrate back eval values in dataset\n    dataset.eval_mask[train_slice] = 0\n\n    return dataset, train_slice\n\n\ndef get_imputer(imputer_name, args):\n    if imputer_name == 'mean':\n        imputer = MeanImputer(in_sample=args.in_sample)\n    elif imputer_name == 'knn':\n        imputer = SpatialKNNImputer(adj=args.adj, k=args.k)\n    elif imputer_name == 'mf':\n        imputer = 
MatrixFactorizationImputer(rank=args.rank)\n    elif imputer_name == 'mice':\n        imputer = MICEImputer(max_iter=args.mice_iterations,\n                              n_nearest_features=args.mice_n_features,\n                              in_sample=args.in_sample)\n    else:\n        raise ValueError(f\"Imputer {imputer_name} not available in this setting.\")\n    return imputer\n\n\ndef run(imputer, dataset, train_slice):\n    test_slice = ~train_slice\n    if args.in_sample:\n        x_train, mask_train = dataset.numpy(), dataset.training_mask\n        y_hat = imputer.predict(x_train, mask_train)[test_slice]\n    else:\n        x_train, mask_train = dataset.numpy()[train_slice], dataset.training_mask[train_slice]\n        imputer.fit(x_train, mask_train)\n        x_test, mask_test = dataset.numpy()[test_slice], dataset.training_mask[test_slice]\n        y_hat = imputer.predict(x_test, mask_test)\n\n    # Evaluate model\n    y_true = dataset.numpy()[test_slice]\n    eval_mask = dataset.eval_mask[test_slice]\n\n    for metric, metric_fn in metrics.items():\n        error = metric_fn(y_hat, y_true, eval_mask)\n        print(f'{imputer.name} on {ds_name} {metric}: {error:.4f}')\n\n\nif __name__ == '__main__':\n\n    args = parse_args()\n    print(args.__dict__)\n\n    for ds_name in args.datasets:\n\n        dataset, train_slice = get_dataset(ds_name)\n        args.adj = dataset.get_similarity(thr=0.1)\n\n        # Instantiate imputers\n        imputers = [get_imputer(name, args) for name in args.imputers]\n\n        for imputer in imputers:\n            n_runs = 1 if imputer.is_deterministic else args.n_runs\n            for _ in range(n_runs):\n                run(imputer, dataset, train_slice)\n"
  },
  {
    "path": "scripts/run_imputation.py",
    "content": "import copy\nimport datetime\nimport os\nimport pathlib\nfrom argparse import ArgumentParser\n\nimport numpy as np\nimport pytorch_lightning as pl\nimport torch\nimport torch.nn.functional as F\nimport yaml\nfrom pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\nfrom pytorch_lightning.loggers import TensorBoardLogger\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\n\nfrom lib import fillers, datasets, config\nfrom lib.data.datamodule import SpatioTemporalDataModule\nfrom lib.data.imputation_dataset import ImputationDataset, GraphImputationDataset\nfrom lib.nn import models\nfrom lib.nn.utils.metric_base import MaskedMetric\nfrom lib.nn.utils.metrics import MaskedMAE, MaskedMAPE, MaskedMSE, MaskedMRE\nfrom lib.utils import parser_utils, numpy_metrics, ensure_list, prediction_dataframe\nfrom lib.utils.parser_utils import str_to_bool\n\n\ndef has_graph_support(model_cls):\n    return model_cls in [models.GRINet, models.MPGRUNet, models.BiMPGRUNet]\n\n\ndef get_model_classes(model_str):\n    if model_str == 'brits':\n        model, filler = models.BRITSNet, fillers.BRITSFiller\n    elif model_str == 'grin':\n        model, filler = models.GRINet, fillers.GraphFiller\n    elif model_str == 'mpgru':\n        model, filler = models.MPGRUNet, fillers.GraphFiller\n    elif model_str == 'bimpgru':\n        model, filler = models.BiMPGRUNet, fillers.GraphFiller\n    elif model_str == 'var':\n        model, filler = models.VARImputer, fillers.Filler\n    elif model_str == 'gain':\n        model, filler = models.RGAINNet, fillers.RGAINFiller\n    elif model_str == 'birnn':\n        model, filler = models.BiRNNImputer, fillers.MultiImputationFiller\n    elif model_str == 'rnn':\n        model, filler = models.RNNImputer, fillers.Filler\n    else:\n        raise ValueError(f'Model {model_str} not available.')\n    return model, filler\n\n\ndef get_dataset(dataset_name):\n    if dataset_name[:3] == 'air':\n        dataset = 
datasets.AirQuality(impute_nans=True, small=dataset_name[3:] == '36')\n    elif dataset_name == 'bay_block':\n        dataset = datasets.MissingValuesPemsBay()\n    elif dataset_name == 'la_block':\n        dataset = datasets.MissingValuesMetrLA()\n    elif dataset_name == 'la_point':\n        dataset = datasets.MissingValuesMetrLA(p_fault=0., p_noise=0.25)\n    elif dataset_name == 'bay_point':\n        dataset = datasets.MissingValuesPemsBay(p_fault=0., p_noise=0.25)\n    else:\n        raise ValueError(f\"Dataset {dataset_name} not available in this setting.\")\n    return dataset\n\n\ndef parse_args():\n    # Argument parser\n    parser = ArgumentParser()\n    parser.add_argument('--seed', type=int, default=-1)\n    parser.add_argument(\"--model-name\", type=str, default='brits')\n    parser.add_argument(\"--dataset-name\", type=str, default='air36')\n    parser.add_argument(\"--config\", type=str, default=None)\n    # Splitting/aggregation params\n    parser.add_argument('--in-sample', type=str_to_bool, nargs='?', const=True, default=False)\n    parser.add_argument('--val-len', type=float, default=0.1)\n    parser.add_argument('--test-len', type=float, default=0.2)\n    parser.add_argument('--aggregate-by', type=str, default='mean')\n    # Training params\n    parser.add_argument('--lr', type=float, default=0.001)\n    parser.add_argument('--epochs', type=int, default=300)\n    parser.add_argument('--patience', type=int, default=40)\n    parser.add_argument('--l2-reg', type=float, default=0.)\n    parser.add_argument('--scaled-target', type=str_to_bool, nargs='?', const=True, default=True)\n    parser.add_argument('--grad-clip-val', type=float, default=5.)\n    parser.add_argument('--grad-clip-algorithm', type=str, default='norm')\n    parser.add_argument('--loss-fn', type=str, default='l1_loss')\n    parser.add_argument('--use-lr-schedule', type=str_to_bool, nargs='?', const=True, default=True)\n    parser.add_argument('--consistency-loss', type=str_to_bool, 
nargs='?', const=True, default=False)\n    parser.add_argument('--whiten-prob', type=float, default=0.05)\n    parser.add_argument('--pred-loss-weight', type=float, default=1.0)\n    parser.add_argument('--warm-up', type=int, default=0)\n    # graph params\n    parser.add_argument(\"--adj-threshold\", type=float, default=0.1)\n    # gain hparams\n    parser.add_argument('--alpha', type=float, default=10.)\n    parser.add_argument('--hint-rate', type=float, default=0.7)\n    parser.add_argument('--g-train-freq', type=int, default=1)\n    parser.add_argument('--d-train-freq', type=int, default=5)\n\n    known_args, _ = parser.parse_known_args()\n    model_cls, _ = get_model_classes(known_args.model_name)\n    parser = model_cls.add_model_specific_args(parser)\n    parser = SpatioTemporalDataModule.add_argparse_args(parser)\n    parser = ImputationDataset.add_argparse_args(parser)\n\n    args = parser.parse_args()\n    if args.config is not None:\n        with open(args.config, 'r') as fp:\n            config_args = yaml.load(fp, Loader=yaml.FullLoader)\n        for arg in config_args:\n            setattr(args, arg, config_args[arg])\n\n    return args\n\n\ndef run_experiment(args):\n    # Set configuration and seed\n    args = copy.deepcopy(args)\n    if args.seed < 0:\n        args.seed = np.random.randint(1e9)\n    torch.set_num_threads(1)\n    pl.seed_everything(args.seed)\n\n    model_cls, filler_cls = get_model_classes(args.model_name)\n    dataset = get_dataset(args.dataset_name)\n\n    ########################################\n    # create logdir and save configuration #\n    ########################################\n\n    exp_name = f\"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{args.seed}\"\n    logdir = os.path.join(config['logs'], args.dataset_name, args.model_name, exp_name)\n    # save config for logging\n    pathlib.Path(logdir).mkdir(parents=True)\n    with open(os.path.join(logdir, 'config.yaml'), 'w') as fp:\n        
yaml.dump(parser_utils.config_dict_from_args(args), fp, indent=4, sort_keys=True)\n\n    ########################################\n    # data module                          #\n    ########################################\n\n    # instantiate dataset\n    dataset_cls = GraphImputationDataset if has_graph_support(model_cls) else ImputationDataset\n    torch_dataset = dataset_cls(*dataset.numpy(return_idx=True),\n                                mask=dataset.training_mask,\n                                eval_mask=dataset.eval_mask,\n                                window=args.window,\n                                stride=args.stride)\n\n    # get train/val/test indices\n    split_conf = parser_utils.filter_function_args(args, dataset.splitter, return_dict=True)\n    train_idxs, val_idxs, test_idxs = dataset.splitter(torch_dataset, **split_conf)\n\n    # configure datamodule\n    data_conf = parser_utils.filter_args(args, SpatioTemporalDataModule, return_dict=True)\n    dm = SpatioTemporalDataModule(torch_dataset, train_idxs=train_idxs, val_idxs=val_idxs, test_idxs=test_idxs,\n                                  **data_conf)\n    dm.setup()\n\n    # if out of sample in air, add values removed for evaluation in train set\n    if not args.in_sample and args.dataset_name[:3] == 'air':\n        dm.torch_dataset.mask[dm.train_slice] |= dm.torch_dataset.eval_mask[dm.train_slice]\n\n    # get adjacency matrix\n    adj = dataset.get_similarity(thr=args.adj_threshold)\n    # force adj with no self loop\n    np.fill_diagonal(adj, 0.)\n\n    ########################################\n    # predictor                            #\n    ########################################\n\n    # model's inputs\n    additional_model_hparams = dict(adj=adj, d_in=dm.d_in, n_nodes=dm.n_nodes)\n    model_kwargs = parser_utils.filter_args(args={**vars(args), **additional_model_hparams},\n                                            target_cls=model_cls,\n                                            
return_dict=True)\n\n    # loss and metrics\n    loss_fn = MaskedMetric(metric_fn=getattr(F, args.loss_fn),\n                           compute_on_step=True,\n                           metric_kwargs={'reduction': 'none'})\n\n    metrics = {'mae': MaskedMAE(compute_on_step=False),\n               'mape': MaskedMAPE(compute_on_step=False),\n               'mse': MaskedMSE(compute_on_step=False),\n               'mre': MaskedMRE(compute_on_step=False)}\n\n    # filler's inputs\n    scheduler_class = CosineAnnealingLR if args.use_lr_schedule else None\n    additional_filler_hparams = dict(model_class=model_cls,\n                                     model_kwargs=model_kwargs,\n                                     optim_class=torch.optim.Adam,\n                                     optim_kwargs={'lr': args.lr,\n                                                   'weight_decay': args.l2_reg},\n                                     loss_fn=loss_fn,\n                                     metrics=metrics,\n                                     scheduler_class=scheduler_class,\n                                     scheduler_kwargs={\n                                         'eta_min': 0.0001,\n                                         'T_max': args.epochs\n                                     },\n                                     alpha=args.alpha,\n                                     hint_rate=args.hint_rate,\n                                     g_train_freq=args.g_train_freq,\n                                     d_train_freq=args.d_train_freq)\n    filler_kwargs = parser_utils.filter_args(args={**vars(args), **additional_filler_hparams},\n                                             target_cls=filler_cls,\n                                             return_dict=True)\n    filler = filler_cls(**filler_kwargs)\n\n    ########################################\n    # training                             #\n    ########################################\n\n    # callbacks\n    
early_stop_callback = EarlyStopping(monitor='val_mae', patience=args.patience, mode='min')\n    checkpoint_callback = ModelCheckpoint(dirpath=logdir, save_top_k=1, monitor='val_mae', mode='min')\n\n    logger = TensorBoardLogger(logdir, name=\"model\")\n\n    trainer = pl.Trainer(max_epochs=args.epochs,\n                         logger=logger,\n                         default_root_dir=logdir,\n                         gpus=1 if torch.cuda.is_available() else None,\n                         gradient_clip_val=args.grad_clip_val,\n                         gradient_clip_algorithm=args.grad_clip_algorithm,\n                         callbacks=[early_stop_callback, checkpoint_callback])\n\n    trainer.fit(filler, datamodule=dm)\n\n    ########################################\n    # testing                              #\n    ########################################\n\n    filler.load_state_dict(torch.load(checkpoint_callback.best_model_path,\n                                      lambda storage, loc: storage)['state_dict'])\n    filler.freeze()\n    trainer.test()\n    filler.eval()\n\n    if torch.cuda.is_available():\n        filler.cuda()\n\n    with torch.no_grad():\n        y_true, y_hat, mask = filler.predict_loader(dm.test_dataloader(), return_mask=True)\n    y_hat = y_hat.detach().cpu().numpy().reshape(y_hat.shape[:3])  # reshape to (eventually) squeeze node channels\n\n    # Test imputations in whole series\n    eval_mask = dataset.eval_mask[dm.test_slice]\n    df_true = dataset.df.iloc[dm.test_slice]\n    metrics = {\n        'mae': numpy_metrics.masked_mae,\n        'mse': numpy_metrics.masked_mse,\n        'mre': numpy_metrics.masked_mre,\n        'mape': numpy_metrics.masked_mape\n    }\n    # Aggregate predictions in dataframes\n    index = dm.torch_dataset.data_timestamps(dm.testset.indices, flatten=False)['horizon']\n    aggr_methods = ensure_list(args.aggregate_by)\n    df_hats = prediction_dataframe(y_hat, index, dataset.df.columns, 
aggregate_by=aggr_methods)\n    df_hats = dict(zip(aggr_methods, df_hats))\n    for aggr_by, df_hat in df_hats.items():\n        # Compute error\n        print(f'- AGGREGATE BY {aggr_by.upper()}')\n        for metric_name, metric_fn in metrics.items():\n            error = metric_fn(df_hat.values, df_true.values, eval_mask).item()\n            print(f' {metric_name}: {error:.4f}')\n\n    return y_true, y_hat, mask\n\n\nif __name__ == '__main__':\n    args = parse_args()\n    run_experiment(args)\n"
  },
  {
    "path": "scripts/run_synthetic.py",
    "content": "import copy\nimport datetime\nimport os\nimport pathlib\nfrom argparse import ArgumentParser\n\nimport numpy as np\nimport pytorch_lightning as pl\nimport torch\nimport torch.nn.functional as F\nimport yaml\nfrom pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\nfrom pytorch_lightning.loggers import TensorBoardLogger\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\n\nfrom lib import fillers, config\nfrom lib.datasets import ChargedParticles\nfrom lib.nn import models\nfrom lib.nn.utils.metric_base import MaskedMetric\nfrom lib.nn.utils.metrics import MaskedMAE, MaskedMAPE, MaskedMSE, MaskedMRE\nfrom lib.utils import parser_utils\nfrom lib.utils.parser_utils import str_to_bool\n\n\ndef has_graph_support(model_cls):\n    return model_cls is models.GRINet\n\n\ndef get_model_classes(model_str):\n    if model_str == 'brits':\n        model, filler = models.BRITSNet, fillers.BRITSFiller\n    elif model_str == 'grin':\n        model, filler = models.GRINet, fillers.GraphFiller\n    else:\n        raise ValueError(f'Model {model_str} not available.')\n    return model, filler\n\n\ndef parse_args():\n    # Argument parser\n    parser = ArgumentParser()\n    parser.add_argument('--seed', type=int, default=-1)\n    parser.add_argument(\"--model-name\", type=str, default='grin')\n    parser.add_argument(\"--config\", type=str, default=None)\n    # Dataset params\n    parser.add_argument('--static-adj', type=str_to_bool, nargs='?', const=True, default=False)\n    parser.add_argument('--window', type=int, default=50)\n    parser.add_argument('--p-block', type=float, default=0.025)\n    parser.add_argument('--p-point', type=float, default=0.025)\n    parser.add_argument('--min-seq', type=int, default=5)\n    parser.add_argument('--max-seq', type=int, default=10)\n    parser.add_argument('--use-exogenous', type=str_to_bool, nargs='?', const=True, default=True)\n    # Splitting/aggregation params\n    parser.add_argument('--val-len', 
type=float, default=0.1)\n    parser.add_argument('--test-len', type=float, default=0.2)\n    # Training params\n    parser.add_argument('--lr', type=float, default=0.001)\n    parser.add_argument('--epochs', type=int, default=300)\n    parser.add_argument('--patience', type=int, default=40)\n    parser.add_argument('--l2-reg', type=float, default=0.)\n    parser.add_argument('--scaled-target', type=str_to_bool, nargs='?', const=True, default=False)\n    parser.add_argument('--grad-clip-val', type=float, default=5.)\n    parser.add_argument('--grad-clip-algorithm', type=str, default='norm')\n    parser.add_argument('--loss-fn', type=str, default='mse_loss')\n    parser.add_argument('--use-lr-schedule', type=str_to_bool, nargs='?', const=True, default=True)\n    parser.add_argument('--whiten-prob', type=float, default=0.05)\n    parser.add_argument('--pred-loss-weight', type=float, default=1.0)\n    parser.add_argument('--warm-up', type=int, default=0)\n    # graph params\n    parser.add_argument(\"--adj-threshold\", type=float, default=0.1)\n\n    known_args, _ = parser.parse_known_args()\n    model_cls, _ = get_model_classes(known_args.model_name)\n    parser = model_cls.add_model_specific_args(parser)\n\n    args = parser.parse_args()\n    if args.config is not None:\n        with open(args.config, 'r') as fp:\n            config_args = yaml.load(fp, Loader=yaml.FullLoader)\n        for arg in config_args:\n            setattr(args, arg, config_args[arg])\n\n    return args\n\n\ndef run_experiment(args):\n    # Set configuration and seed\n    args = copy.deepcopy(args)\n    if args.seed < 0:\n        args.seed = np.random.randint(1e9)\n    torch.set_num_threads(1)\n    pl.seed_everything(args.seed)\n\n    ########################################\n    # load dataset and model               #\n    ########################################\n\n    model_cls, filler_cls = get_model_classes(args.model_name)\n\n    dataset = ChargedParticles(static_adj=args.static_adj,\n 
                              window=args.window,\n                               p_block=args.p_block,\n                               p_point=args.p_point,\n                               max_seq=args.max_seq,\n                               min_seq=args.min_seq,\n                               use_exogenous=args.use_exogenous,\n                               graph_mode=has_graph_support(model_cls))\n\n    dataset.split(args.val_len, args.test_len)\n\n    # get adjacency matrix\n    adj = dataset.get_similarity()\n    np.fill_diagonal(adj, 0.)  # force adj with no self loop\n\n    ########################################\n    # create logdir and save configuration #\n    ########################################\n\n    exp_name = f\"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{args.seed}\"\n    logdir = os.path.join(config['logs'], 'synthetic', args.model_name, exp_name)\n    # save config for logging\n    pathlib.Path(logdir).mkdir(parents=True)\n    with open(os.path.join(logdir, 'config.yaml'), 'w') as fp:\n        yaml.dump(parser_utils.config_dict_from_args(args), fp, indent=4, sort_keys=True)\n\n    ########################################\n    # predictor                            #\n    ########################################\n\n    # model's inputs\n    if has_graph_support(model_cls):\n        model_params = dict(adj=adj, d_in=dataset.n_channels, d_u=dataset.n_exogenous, n_nodes=dataset.n_nodes)\n    else:\n        model_params = dict(d_in=(dataset.n_channels * dataset.n_nodes), d_u=(dataset.n_channels * dataset.n_exogenous))\n    model_kwargs = parser_utils.filter_args(args={**vars(args), **model_params},\n                                            target_cls=model_cls,\n                                            return_dict=True)\n\n    # loss and metrics\n    loss_fn = MaskedMetric(metric_fn=getattr(F, args.loss_fn),\n                           compute_on_step=True,\n                           metric_kwargs={'reduction': 'none'})\n\n   
 metrics = {'mae': MaskedMAE(compute_on_step=False),\n               'mape': MaskedMAPE(compute_on_step=False),\n               'mse': MaskedMSE(compute_on_step=False),\n               'mre': MaskedMRE(compute_on_step=False)}\n\n    # filler's inputs\n    scheduler_class = CosineAnnealingLR if args.use_lr_schedule else None\n    additional_filler_hparams = dict(model_class=model_cls,\n                                     model_kwargs=model_kwargs,\n                                     optim_class=torch.optim.Adam,\n                                     optim_kwargs={'lr': args.lr,\n                                                   'weight_decay': args.l2_reg},\n                                     loss_fn=loss_fn,\n                                     metrics=metrics,\n                                     scheduler_class=scheduler_class,\n                                     scheduler_kwargs={\n                                         'eta_min': 0.0001,\n                                         'T_max': args.epochs\n                                     },\n                                     alpha=args.alpha,\n                                     hint_rate=args.hint_rate,\n                                     g_train_freq=args.g_train_freq,\n                                     d_train_freq=args.d_train_freq)\n    filler_kwargs = parser_utils.filter_args(args={**vars(args), **additional_filler_hparams},\n                                             target_cls=filler_cls,\n                                             return_dict=True)\n    filler = filler_cls(**filler_kwargs)\n\n    ########################################\n    # logging options                      #\n    ########################################\n\n    # log number of parameters\n    args.trainable_parameters = filler.trainable_parameters\n\n    # log statistics on masks\n    for mask_type in ['mask', 'eval_mask', 'training_mask']:\n        mask_type_mean = getattr(dataset, 
mask_type).float().mean().item()\n        setattr(args, mask_type, mask_type_mean)\n\n    print(args)\n\n    ########################################\n    # training                             #\n    ########################################\n\n    # callbacks\n    early_stop_callback = EarlyStopping(monitor='val_mse', patience=args.patience, mode='min')\n    checkpoint_callback = ModelCheckpoint(dirpath=logdir, save_top_k=1, monitor='val_mse', mode='min')\n\n    logger = TensorBoardLogger(logdir, name=\"model\")\n\n    trainer = pl.Trainer(max_epochs=args.epochs,\n                         default_root_dir=logdir,\n                         logger=logger,\n                         gpus=1 if torch.cuda.is_available() else None,\n                         gradient_clip_val=args.grad_clip_val,\n                         gradient_clip_algorithm=args.grad_clip_algorithm,\n                         callbacks=[early_stop_callback, checkpoint_callback])\n\n    trainer.fit(filler,\n                train_dataloader=dataset.train_dataloader(batch_size=args.batch_size),\n                val_dataloaders=dataset.val_dataloader(batch_size=args.batch_size))\n\n    ########################################\n    # testing                              #\n    ########################################\n\n    filler.load_state_dict(torch.load(checkpoint_callback.best_model_path,\n                                      lambda storage, loc: storage)['state_dict'])\n    filler.freeze()\n    trainer.test(filler, test_dataloaders=dataset.test_dataloader(batch_size=args.batch_size))\n    filler.eval()\n\n\nif __name__ == '__main__':\n    args = parse_args()\n    run_experiment(args)\n"
  }
]