Repository: Graph-Machine-Learning-Group/grin
Branch: main
Commit: 4a28afbb0926
Files: 102
Total size: 201.4 KB

Directory structure:
gitextract__pvx4bn4/

├── .gitignore
├── README.md
├── conda_env.yml
├── config/
│   ├── bimpgru/
│   │   ├── air.yaml
│   │   ├── air36.yaml
│   │   ├── bay_block.yaml
│   │   ├── bay_point.yaml
│   │   ├── irish_block.yaml
│   │   ├── irish_point.yaml
│   │   ├── la_block.yaml
│   │   └── la_point.yaml
│   ├── brits/
│   │   ├── air.yaml
│   │   ├── air36.yaml
│   │   ├── bay_block.yaml
│   │   ├── bay_point.yaml
│   │   ├── irish_block.yaml
│   │   ├── irish_point.yaml
│   │   ├── la_block.yaml
│   │   ├── la_point.yaml
│   │   └── synthetic.yaml
│   ├── grin/
│   │   ├── air.yaml
│   │   ├── air36.yaml
│   │   ├── bay_block.yaml
│   │   ├── bay_point.yaml
│   │   ├── irish_block.yaml
│   │   ├── irish_point.yaml
│   │   ├── la_block.yaml
│   │   ├── la_point.yaml
│   │   └── synthetic.yaml
│   ├── mpgru/
│   │   ├── air.yaml
│   │   ├── air36.yaml
│   │   ├── bay_block.yaml
│   │   ├── bay_point.yaml
│   │   ├── irish_block.yaml
│   │   ├── irish_point.yaml
│   │   ├── la_block.yaml
│   │   └── la_point.yaml
│   ├── rgain/
│   │   ├── air.yaml
│   │   ├── air36.yaml
│   │   ├── bay_block.yaml
│   │   ├── bay_point.yaml
│   │   ├── irish_block.yaml
│   │   ├── irish_point.yaml
│   │   ├── la_block.yaml
│   │   └── la_point.yaml
│   └── var/
│       ├── air.yaml
│       ├── air36.yaml
│       ├── bay_block.yaml
│       ├── bay_point.yaml
│       ├── irish_block.yaml
│       ├── irish_point.yaml
│       ├── la_block.yaml
│       └── la_point.yaml
├── lib/
│   ├── __init__.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── datamodule/
│   │   │   ├── __init__.py
│   │   │   └── spatiotemporal.py
│   │   ├── imputation_dataset.py
│   │   ├── preprocessing/
│   │   │   ├── __init__.py
│   │   │   └── scalers.py
│   │   ├── spatiotemporal_dataset.py
│   │   └── temporal_dataset.py
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── air_quality.py
│   │   ├── metr_la.py
│   │   ├── pd_dataset.py
│   │   ├── pems_bay.py
│   │   └── synthetic.py
│   ├── fillers/
│   │   ├── __init__.py
│   │   ├── britsfiller.py
│   │   ├── filler.py
│   │   ├── graphfiller.py
│   │   ├── multi_imputation_filler.py
│   │   └── rgainfiller.py
│   ├── nn/
│   │   ├── __init__.py
│   │   ├── layers/
│   │   │   ├── __init__.py
│   │   │   ├── gcrnn.py
│   │   │   ├── gril.py
│   │   │   ├── imputation.py
│   │   │   ├── mpgru.py
│   │   │   ├── rits.py
│   │   │   ├── spatial_attention.py
│   │   │   └── spatial_conv.py
│   │   ├── models/
│   │   │   ├── __init__.py
│   │   │   ├── brits.py
│   │   │   ├── grin.py
│   │   │   ├── mpgru.py
│   │   │   ├── rgain.py
│   │   │   ├── rnn_imputers.py
│   │   │   └── var.py
│   │   └── utils/
│   │       ├── __init__.py
│   │       ├── metric_base.py
│   │       ├── metrics.py
│   │       └── ops.py
│   └── utils/
│       ├── __init__.py
│       ├── numpy_metrics.py
│       ├── parser_utils.py
│       └── utils.py
├── requirements.txt
└── scripts/
    ├── run_baselines.py
    ├── run_imputation.py
    └── run_synthetic.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.DS_STORE


================================================
FILE: README.md
================================================
# Filling the G_ap_s: Multivariate Time Series Imputation by Graph Neural Networks (ICLR 2022 - [open review](https://openreview.net/forum?id=kOu3-S3wJ7) - [pdf](https://openreview.net/pdf?id=kOu3-S3wJ7))

[![ICLR](https://img.shields.io/badge/ICLR-2022-blue.svg?style=flat-square)](https://openreview.net/forum?id=kOu3-S3wJ7)
[![PDF](https://img.shields.io/badge/%E2%87%A9-PDF-orange.svg?style=flat-square)](https://openreview.net/pdf?id=kOu3-S3wJ7)
[![arXiv](https://img.shields.io/badge/arXiv-2108.00298-b31b1b.svg?style=flat-square)](https://arxiv.org/abs/2108.00298)

This repository contains the code to reproduce the experiments presented in the paper "Filling the G_ap_s: Multivariate Time Series Imputation by Graph Neural Networks" (ICLR 2022). In the paper, we propose a graph neural network architecture for multivariate time series imputation that achieves state-of-the-art results on several benchmarks.

**Authors**: [Andrea Cini](mailto:andrea.cini@usi.ch), [Ivan Marisca](mailto:ivan.marisca@usi.ch), Cesare Alippi


**‼️ PyG implementation of GRIN is now available inside [Torch Spatiotemporal](https://github.com/TorchSpatiotemporal/tsl), a library built to accelerate research on neural spatiotemporal data processing methods, with a focus on Graph Neural Networks.**

---

<h2 align=center>GRIN in a nutshell</h2>

The [paper](https://arxiv.org/abs/2108.00298) introduces __GRIN__, a method and an architecture to exploit relational inductive biases to reconstruct missing values in multivariate time series coming from sensor networks. GRIN features a bidirectional recurrent GNN which learns __spatio-temporal node-level representations__ tailored to reconstruct observations at neighboring nodes.

<p align=center>
  <a href="https://github.com/marshka/sinfony">
    <img src="./grin.png" alt="Logo"/>
  </a>
</p>
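
To make the idea of exploiting relational structure concrete, here is a toy, non-learned illustration (it is *not* GRIN): a missing reading at one sensor is estimated at a single time step from the observed readings of its graph neighbours, weighted by the adjacency matrix. GRIN replaces this fixed weighted average with learned message passing inside a bidirectional recurrent cell. The function name `neighbor_fill` and the weighting scheme are illustrative assumptions.

```python
from typing import List, Optional


def neighbor_fill(values: List[Optional[float]],
                  adj: List[List[float]]) -> List[Optional[float]]:
    """Fill missing node values (None) with the adjacency-weighted mean of the
    observed neighbours at the same time step.

    Hypothetical helper for illustration only; GRIN learns this aggregation.
    """
    n = len(values)
    filled: List[Optional[float]] = list(values)
    for i in range(n):
        if values[i] is not None:
            continue  # observed value: keep it unchanged
        num, den = 0.0, 0.0
        for j in range(n):
            w = adj[i][j]
            if j != i and w > 0 and values[j] is not None:
                num += w * values[j]
                den += w
        # If no neighbour is observed, the value stays missing.
        filled[i] = num / den if den > 0 else None
    return filled


# Three sensors on a path graph 0 - 1 - 2; the reading of sensor 1 is missing.
adj = [[0.0, 1.0, 0.0],
       [1.0, 0.0, 3.0],
       [0.0, 3.0, 0.0]]
print(neighbor_fill([10.0, None, 20.0], adj))  # [10.0, 17.5, 20.0]
```

The stronger edge to sensor 2 pulls the estimate towards its reading, which is the kind of spatial dependency GRIN is designed to learn rather than fix a priori.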

---

## Directory structure

The directory is structured as follows:

```
.
├── config
│   ├── bimpgru
│   ├── brits
│   ├── grin
│   ├── mpgru
│   ├── rgain
│   └── var
├── datasets
│   ├── air_quality
│   ├── metr_la
│   ├── pems_bay
│   └── synthetic
├── lib
│   ├── __init__.py
│   ├── data
│   ├── datasets
│   ├── fillers
│   ├── nn
│   └── utils
├── requirements.txt
└── scripts
    ├── run_baselines.py
    ├── run_imputation.py
    └── run_synthetic.py

```
Note that, given their size, the datasets are not included in the repository. See the next section for download instructions.

## Datasets

All the datasets used in the experiments, except CER-E, are open and can be downloaded from this [link](https://mega.nz/folder/qwwG3Qba#c6qFTeT7apmZKKyEunCzSg). The CER-E dataset can be obtained free of charge for research purposes by following the instructions at this [link](https://www.ucd.ie/issda/data/commissionforenergyregulationcer/). We recommend storing the downloaded datasets in a folder named `datasets` inside this directory.

## Configuration files

The `config` directory stores all the configuration files used to run the experiments. They are organized into folders, one per model.
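
Most configuration files set `samples_per_epoch: 5120`, so the number of optimisation steps per epoch depends on `batch_size` (32 for most models, 64 for `var`). A minimal sketch of that arithmetic, assuming each epoch draws `samples_per_epoch` windows and any remainder forms a final smaller batch; the helper `batches_per_epoch` is hypothetical and the dicts reproduce only two keys from the YAML files.

```python
def batches_per_epoch(cfg: dict) -> int:
    """Number of mini-batches drawn per epoch (hypothetical helper).

    Assumes `samples_per_epoch` windows are split into batches of
    `batch_size`, with a possible smaller last batch.
    """
    samples, bs = cfg["samples_per_epoch"], cfg["batch_size"]
    return -(-samples // bs)  # ceiling division on integers


grin_air = {"samples_per_epoch": 5120, "batch_size": 32}  # config/grin/air.yaml
var_air = {"samples_per_epoch": 5120, "batch_size": 64}   # config/var/air.yaml

print(batches_per_epoch(grin_air))  # 160
print(batches_per_epoch(var_air))   # 80
```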

## Library

The support code, including the models and the dataset readers, is packaged in a Python library named `lib`. To change the paths to the dataset locations, edit the library's `__init__.py` file (`lib/__init__.py`).
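
`lib/__init__.py` resolves every dataset path relative to the repository root (`base_dir`). The sketch below mirrors that joining rule and reports which expected dataset folders are missing, which can be handy before launching a run. The `check_datasets` function is hypothetical and not part of `lib`; the relative paths are copied from `lib/__init__.py`.

```python
import os
from typing import Dict, List

# Relative locations copied from lib/__init__.py.
DATASETS_PATH: Dict[str, str] = {
    'air': 'datasets/air_quality',
    'la': 'datasets/metr_la',
    'bay': 'datasets/pems_bay',
    'synthetic': 'datasets/synthetic',
}


def check_datasets(base_dir: str) -> List[str]:
    """Return the names of the datasets whose folder is missing under base_dir."""
    missing = []
    for name, rel_path in DATASETS_PATH.items():
        # Same rule as lib/__init__.py: os.path.join(base_dir, relative_path)
        if not os.path.isdir(os.path.join(base_dir, rel_path)):
            missing.append(name)
    return missing


if __name__ == '__main__':
    # Run from the repository root to see which datasets still need downloading.
    print(check_datasets(os.getcwd()))
```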

## Scripts

The scripts used for the experiments in the paper are in the `scripts` folder.

* `run_baselines.py` is used to compute the metrics for the `MEAN`, `KNN`, `MF` and `MICE` imputation methods. An example of usage is

	```
	python ./scripts/run_baselines.py --datasets air36 air --imputers mean knn --k 10 --in-sample True --n-runs 5
	```

* `run_imputation.py` is used to compute the metrics for the deep imputation methods. An example of usage is

	```
	python ./scripts/run_imputation.py --config config/grin/air36.yaml --in-sample False
	```

* `run_synthetic.py` is used for the experiments on the synthetic datasets. An example of usage is

	```
	python ./scripts/run_synthetic.py --config config/grin/synthetic.yaml --static-adj False
	```
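
Imputation metrics are typically computed only at the positions that were artificially hidden for evaluation. To illustrate that protocol together with the simplest baseline, the sketch below fills each sensor's missing entries with the mean of that sensor's observed values and computes the MAE over an evaluation mask. It assumes a per-sensor (column-wise) mean and plain Python lists; it is not the implementation used in `run_baselines.py` or `lib/utils/numpy_metrics.py`, and both function names are hypothetical.

```python
from typing import List, Optional

Matrix = List[List[Optional[float]]]  # rows = time steps, columns = sensors


def mean_impute(x: Matrix) -> List[List[float]]:
    """Replace None entries with the mean of the observed values in the same column."""
    n_cols = len(x[0])
    means = []
    for c in range(n_cols):
        observed = [row[c] for row in x if row[c] is not None]
        # Arbitrary fallback of 0.0 if a sensor has no observations at all.
        means.append(sum(observed) / len(observed) if observed else 0.0)
    return [[means[c] if v is None else v for c, v in enumerate(row)] for row in x]


def masked_mae(y_hat, y_true, eval_mask) -> float:
    """Mean absolute error restricted to positions where eval_mask is True."""
    errs = [abs(p - t)
            for rp, rt, rm in zip(y_hat, y_true, eval_mask)
            for p, t, m in zip(rp, rt, rm) if m]
    if not errs:
        raise ValueError("eval_mask selects no entries")
    return sum(errs) / len(errs)


# Two sensors, three time steps; entry (1, 0) was hidden for evaluation.
y_true = [[1.0, 4.0], [5.0, 5.0], [3.0, 6.0]]
x_obs = [[1.0, 4.0], [None, 5.0], [3.0, 6.0]]
mask = [[False, False], [True, False], [False, False]]
y_hat = mean_impute(x_obs)
print(y_hat[1][0], masked_mae(y_hat, y_true, mask))  # 2.0 3.0
```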

## Requirements

We ran all the experiments with `python 3.8`; see `requirements.txt` for the list of `pip` dependencies.
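
Because the experiments were run on Python 3.8 with pinned packages (for example `pytorch-lightning==1.4` in `conda_env.yml`), a small guard can warn when a different interpreter is in use. This is a convenience sketch, not part of the repository; `check_python` and `TESTED_PYTHON` are hypothetical names.

```python
import sys
import warnings

TESTED_PYTHON = (3, 8)  # version reported in this README


def check_python(version_info=sys.version_info) -> bool:
    """Return True if the interpreter matches the tested major.minor version,
    otherwise emit a warning and return False."""
    current = tuple(version_info[:2])
    if current != TESTED_PYTHON:
        warnings.warn(f"Experiments were run on Python "
                      f"{TESTED_PYTHON[0]}.{TESTED_PYTHON[1]}, "
                      f"found {current[0]}.{current[1]}")
        return False
    return True


check_python()
```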

## Bibtex reference

If you find this code useful, please consider citing our paper:

```
@inproceedings{cini2022filling,
    title={Filling the G\_ap\_s: Multivariate Time Series Imputation by Graph Neural Networks},
    author={Andrea Cini and Ivan Marisca and Cesare Alippi},
    booktitle={International Conference on Learning Representations},
    year={2022},
    url={https://openreview.net/forum?id=kOu3-S3wJ7}
}
```


================================================
FILE: conda_env.yml
================================================
name: grin
channels:
  - defaults
  - pytorch
  - conda-forge
dependencies:
  - pip
  - pytables
  - python=3.8
  - pytorch=1.8
  - torchvision
  - torchaudio
  - wheel
  - pip:
      - einops
      - fancyimpute==0.6
      - h5py
      - openpyxl
      - pandas
      - pytorch-lightning==1.4
      - pyyaml
      - scikit-learn
      - scipy
      - tensorboard


================================================
FILE: config/bimpgru/air.yaml
================================================
dataset_name: 'air'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'bimpgru'
d_hidden: 64
d_emb: 8
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false
merge: 'mlp'

================================================
FILE: config/bimpgru/air36.yaml
================================================
dataset_name: 'air36'
window: 36

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'bimpgru'
d_hidden: 64
d_emb: 8
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false
merge: 'mlp'

================================================
FILE: config/bimpgru/bay_block.yaml
================================================
dataset_name: 'bay_block'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'bimpgru'
d_hidden: 64
d_emb: 8
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false
merge: 'mlp'

================================================
FILE: config/bimpgru/bay_point.yaml
================================================
dataset_name: 'bay_point'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'bimpgru'
d_hidden: 64
d_emb: 8
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false
merge: 'mlp'

================================================
FILE: config/bimpgru/irish_block.yaml
================================================
dataset_name: 'irish_block'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'bimpgru'
d_hidden: 64
d_emb: 8
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false
merge: 'mlp'

================================================
FILE: config/bimpgru/irish_point.yaml
================================================
dataset_name: 'irish_point'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'bimpgru'
d_hidden: 64
d_emb: 8
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false
merge: 'mlp'

================================================
FILE: config/bimpgru/la_block.yaml
================================================
dataset_name: 'la_block'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'bimpgru'
d_hidden: 64
d_emb: 8
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false
merge: 'mlp'

================================================
FILE: config/bimpgru/la_point.yaml
================================================
dataset_name: 'la_point'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'bimpgru'
d_hidden: 64
d_emb: 8
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false
merge: 'mlp'

================================================
FILE: config/brits/air.yaml
================================================
dataset_name: 'air'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'brits'
d_hidden: 128

================================================
FILE: config/brits/air36.yaml
================================================
dataset_name: 'air36'
window: 36

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'brits'
d_hidden: 64

================================================
FILE: config/brits/bay_block.yaml
================================================
dataset_name: 'bay_block'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'brits'
d_hidden: 256

================================================
FILE: config/brits/bay_point.yaml
================================================
dataset_name: 'bay_point'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'brits'
d_hidden: 256

================================================
FILE: config/brits/irish_block.yaml
================================================
dataset_name: 'irish_block'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'brits'
d_hidden: 256

================================================
FILE: config/brits/irish_point.yaml
================================================
dataset_name: 'irish_point'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'brits'
d_hidden: 256

================================================
FILE: config/brits/la_block.yaml
================================================
dataset_name: 'la_block'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'brits'
d_hidden: 128

================================================
FILE: config/brits/la_point.yaml
================================================
dataset_name: 'la_point'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'brits'
d_hidden: 128

================================================
FILE: config/brits/synthetic.yaml
================================================
window: 36
p_block: 0.025
p_point: 0.025
min_seq: 4
max_seq: 9
use_exogenous: False

epochs: 200
batch_size: 32

model_name: 'brits'
d_hidden: 32

================================================
FILE: config/grin/air.yaml
================================================
dataset_name: 'air'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'grin'
pred_loss_weight: 1

d_hidden: 64
d_emb: 8
d_ff: 64
ff_dropout: 0
kernel_size: 2
decoder_order: 1
n_layers: 1
layer_norm: false
merge: 'mlp'


================================================
FILE: config/grin/air36.yaml
================================================
dataset_name: 'air36'
window: 36

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'grin'
pred_loss_weight: 1

d_hidden: 64
d_emb: 8
d_ff: 64
ff_dropout: 0
kernel_size: 2
decoder_order: 1
n_layers: 1
layer_norm: false
merge: 'mlp'


================================================
FILE: config/grin/bay_block.yaml
================================================
dataset_name: 'bay_block'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'grin'
pred_loss_weight: 1

d_hidden: 64
d_emb: 8
d_ff: 64
ff_dropout: 0
kernel_size: 2
decoder_order: 1
n_layers: 1
layer_norm: false
merge: 'mlp'


================================================
FILE: config/grin/bay_point.yaml
================================================
dataset_name: 'bay_point'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'grin'
pred_loss_weight: 1

d_hidden: 64
d_emb: 8
d_ff: 64
ff_dropout: 0
kernel_size: 2
decoder_order: 1
n_layers: 1
layer_norm: false
merge: 'mlp'


================================================
FILE: config/grin/irish_block.yaml
================================================
dataset_name: 'irish_block'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'grin'
pred_loss_weight: 1

d_hidden: 64
d_emb: 8
d_ff: 64
ff_dropout: 0
kernel_size: 2
decoder_order: 1
n_layers: 1
layer_norm: false
merge: 'mlp'


================================================
FILE: config/grin/irish_point.yaml
================================================
dataset_name: 'irish_point'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'grin'
pred_loss_weight: 1

d_hidden: 64
d_emb: 8
d_ff: 64
ff_dropout: 0
kernel_size: 2
decoder_order: 1
n_layers: 1
layer_norm: false
merge: 'mlp'


================================================
FILE: config/grin/la_block.yaml
================================================
dataset_name: 'la_block'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'grin'
pred_loss_weight: 1

d_hidden: 64
d_emb: 8
d_ff: 64
ff_dropout: 0
kernel_size: 2
decoder_order: 1
n_layers: 1
layer_norm: false
merge: 'mlp'


================================================
FILE: config/grin/la_point.yaml
================================================
dataset_name: 'la_point'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'grin'
pred_loss_weight: 1

d_hidden: 64
d_emb: 8
d_ff: 64
ff_dropout: 0
kernel_size: 2
decoder_order: 1
n_layers: 1
layer_norm: false
merge: 'mlp'


================================================
FILE: config/grin/synthetic.yaml
================================================
window: 36
p_block: 0.025
p_point: 0.025
min_seq: 4
max_seq: 9
use_exogenous: False

epochs: 200
batch_size: 32

model_name: 'grin'
d_hidden: 16
d_emb: 0
d_ff: 16
ff_dropout: 0
kernel_size: 1
decoder_order: 1
n_layers: 1
layer_norm: false
merge: 'mlp'


================================================
FILE: config/mpgru/air.yaml
================================================
dataset_name: 'air'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'mpgru'
pred_loss_weight: 1
d_hidden: 64
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false

================================================
FILE: config/mpgru/air36.yaml
================================================
dataset_name: 'air36'
window: 36

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'mpgru'
pred_loss_weight: 1
d_hidden: 64
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false

================================================
FILE: config/mpgru/bay_block.yaml
================================================
dataset_name: 'bay_block'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'mpgru'
pred_loss_weight: 1
d_hidden: 64
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false

================================================
FILE: config/mpgru/bay_point.yaml
================================================
dataset_name: 'bay_point'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'mpgru'
pred_loss_weight: 1
d_hidden: 64
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false

================================================
FILE: config/mpgru/irish_block.yaml
================================================
dataset_name: 'irish_block'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'mpgru'
pred_loss_weight: 1
d_hidden: 64
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false

================================================
FILE: config/mpgru/irish_point.yaml
================================================
dataset_name: 'irish_point'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'mpgru'
pred_loss_weight: 1
d_hidden: 64
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false

================================================
FILE: config/mpgru/la_block.yaml
================================================
dataset_name: 'la_block'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'mpgru'
pred_loss_weight: 1
d_hidden: 64
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false

================================================
FILE: config/mpgru/la_point.yaml
================================================
dataset_name: 'la_point'
window: 24

adj_threshold: 0.1

detrend: False
scale: True
scaling_axis: 'global'  # ['channels', 'global']
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
aggregate_by: ['mean']

model_name: 'mpgru'
pred_loss_weight: 1
d_hidden: 64
d_ff: 64
dropout: 0
kernel_size: 2
n_layers: 1
layer_norm: false

================================================
FILE: config/rgain/air.yaml
================================================
dataset_name: 'air'
window: 24
whiten_prob: 0.2

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
loss_fn: mse_loss
consistency_loss: False
use_lr_schedule: True
grad_clip_val: -1
aggregate_by: ['mean']

model_name: 'gain'
d_model: 128
d_z: 4
dropout: 0.2
inject_noise: true
alpha: 20
g_train_freq: 3
d_train_freq: 1

================================================
FILE: config/rgain/air36.yaml
================================================
dataset_name: 'air36'
window: 36
whiten_prob: 0.2

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
loss_fn: mse_loss
consistency_loss: False
use_lr_schedule: True
grad_clip_val: -1
aggregate_by: ['mean']

model_name: 'gain'
d_model: 64
d_z: 4
dropout: 0.1
inject_noise: true
alpha: 20
g_train_freq: 3
d_train_freq: 1

================================================
FILE: config/rgain/bay_block.yaml
================================================
dataset_name: 'bay_block'
window: 24
whiten_prob: 0.2

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
loss_fn: mse_loss
consistency_loss: False
use_lr_schedule: True
grad_clip_val: -1
aggregate_by: ['mean']

model_name: 'gain'
d_model: 256
d_z: 4
dropout: 0.2
inject_noise: true
alpha: 20
g_train_freq: 3
d_train_freq: 1

================================================
FILE: config/rgain/bay_point.yaml
================================================
dataset_name: 'bay_point'
window: 24
whiten_prob: 0.2

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
loss_fn: mse_loss
consistency_loss: False
use_lr_schedule: True
grad_clip_val: -1
aggregate_by: ['mean']

model_name: 'gain'
d_model: 256
d_z: 4
dropout: 0.2
inject_noise: true
alpha: 20
g_train_freq: 3
d_train_freq: 1

================================================
FILE: config/rgain/irish_block.yaml
================================================
dataset_name: 'irish_block'
window: 24
whiten_prob: 0.2

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
loss_fn: mse_loss
consistency_loss: False
use_lr_schedule: True
grad_clip_val: -1
aggregate_by: ['mean']

model_name: 'gain'
d_model: 256
d_z: 4
dropout: 0.2
inject_noise: true
alpha: 20
g_train_freq: 3
d_train_freq: 1

================================================
FILE: config/rgain/irish_point.yaml
================================================
dataset_name: 'irish_point'
window: 24
whiten_prob: 0.2

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
loss_fn: mse_loss
consistency_loss: False
use_lr_schedule: True
grad_clip_val: -1
aggregate_by: ['mean']

model_name: 'gain'
d_model: 256
d_z: 4
dropout: 0.2
inject_noise: true
alpha: 20
g_train_freq: 3
d_train_freq: 1

================================================
FILE: config/rgain/la_block.yaml
================================================
dataset_name: 'la_block'
window: 24
whiten_prob: 0.2

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
loss_fn: mse_loss
consistency_loss: False
use_lr_schedule: True
grad_clip_val: -1
aggregate_by: ['mean']

model_name: 'gain'
d_model: 128
d_z: 4
dropout: 0.2
inject_noise: true
alpha: 20
g_train_freq: 3
d_train_freq: 1

================================================
FILE: config/rgain/la_point.yaml
================================================
dataset_name: 'la_point'
window: 24
whiten_prob: 0.2

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 160 batches of 32
batch_size: 32
loss_fn: mse_loss
consistency_loss: False
use_lr_schedule: True
grad_clip_val: -1
aggregate_by: ['mean']

model_name: 'gain'
d_model: 128
d_z: 4
dropout: 0.2
inject_noise: true
alpha: 20
g_train_freq: 3
d_train_freq: 1

================================================
FILE: config/var/air.yaml
================================================
dataset_name: 'air'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 80 batches of 64
lr: 0.0005
batch_size: 64
aggregate_by: ['mean']

model_name: 'var'
order: 5
padding: 'mean'

================================================
FILE: config/var/air36.yaml
================================================
dataset_name: 'air36'
window: 36

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 80 batches of 64
lr: 0.0005
batch_size: 64
aggregate_by: ['mean']

model_name: 'var'
order: 5
padding: 'mean'

================================================
FILE: config/var/bay_block.yaml
================================================
dataset_name: 'bay_block'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 80 batches of 64
lr: 0.0005
batch_size: 64
aggregate_by: ['mean']

model_name: 'var'
order: 5
padding: 'mean'

================================================
FILE: config/var/bay_point.yaml
================================================
dataset_name: 'bay_point'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 80 batches of 64
lr: 0.0005
batch_size: 64
aggregate_by: ['mean']

model_name: 'var'
order: 5
padding: 'mean'

================================================
FILE: config/var/irish_block.yaml
================================================
dataset_name: 'irish_block'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 80 batches of 64
lr: 0.0005
batch_size: 64
aggregate_by: ['mean']

model_name: 'var'
order: 5
padding: 'mean'

================================================
FILE: config/var/irish_point.yaml
================================================
dataset_name: 'irish_point'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 80 batches of 64
lr: 0.0005
batch_size: 64
aggregate_by: ['mean']

model_name: 'var'
order: 5
padding: 'mean'

================================================
FILE: config/var/la_block.yaml
================================================
dataset_name: 'la_block'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 80 batches of 64
lr: 0.0005
batch_size: 64
aggregate_by: ['mean']

model_name: 'var'
order: 5
padding: 'mean'

================================================
FILE: config/var/la_point.yaml
================================================
dataset_name: 'la_point'
window: 24

detrend: False
scale: True
scaling_axis: 'channels'
scaled_target: True

epochs: 300
samples_per_epoch: 5120  # 80 batches of 64
lr: 0.0005
batch_size: 64
aggregate_by: ['mean']

model_name: 'var'
order: 5
padding: 'mean'
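In these configs `samples_per_epoch` is the number of windows drawn per epoch: when it is set, `SpatioTemporalDataModule.train_dataloader` uses a `RandomSampler` with replacement and passes `drop_last=True`, so the number of training batches per epoch should be `samples_per_epoch // batch_size`. A small sketch of that arithmetic; the helper `batches_per_epoch` is hypothetical and not part of the repository:

```python
def batches_per_epoch(samples_per_epoch: int, batch_size: int, drop_last: bool = True) -> int:
    """Number of batches a loader yields for a given number of samples (hypothetical helper)."""
    full_batches, remainder = divmod(samples_per_epoch, batch_size)
    # With drop_last=True the trailing partial batch is discarded
    if drop_last or remainder == 0:
        return full_batches
    return full_batches + 1


print(batches_per_epoch(5120, 64))  # 80 batches with batch_size: 64
print(batches_per_epoch(5120, 32))  # 160 batches with batch_size: 32
```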

================================================
FILE: lib/__init__.py
================================================
import os

base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

config = {
    'logs': 'logs/'
}
datasets_path = {
    'air': 'datasets/air_quality',
    'la': 'datasets/metr_la',
    'bay': 'datasets/pems_bay',
    'synthetic': 'datasets/synthetic'
}
epsilon = 1e-8

for k, v in config.items():
    config[k] = os.path.join(base_dir, v)
for k, v in datasets_path.items():
    datasets_path[k] = os.path.join(base_dir, v)


================================================
FILE: lib/data/__init__.py
================================================
from .temporal_dataset import TemporalDataset
from .spatiotemporal_dataset import SpatioTemporalDataset


================================================
FILE: lib/data/datamodule/__init__.py
================================================
from .spatiotemporal import SpatioTemporalDataModule


================================================
FILE: lib/data/datamodule/spatiotemporal.py
================================================
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Subset, RandomSampler

from .. import TemporalDataset, SpatioTemporalDataset
from ..preprocessing import StandardScaler, MinMaxScaler
from ...utils import ensure_list
from ...utils.parser_utils import str_to_bool


class SpatioTemporalDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule for TemporalDataset and SpatioTemporalDataset objects.
    """

    def __init__(self, dataset: TemporalDataset,
                 scale=True,
                 scaling_axis='channels',
                 scaling_type='std',
                 scale_exogenous=None,
                 train_idxs=None,
                 val_idxs=None,
                 test_idxs=None,
                 batch_size=32,
                 workers=1,
                 samples_per_epoch=None):
        super(SpatioTemporalDataModule, self).__init__()
        self.torch_dataset = dataset
        # splitting
        self.trainset = Subset(self.torch_dataset, train_idxs if train_idxs is not None else [])
        self.valset = Subset(self.torch_dataset, val_idxs if val_idxs is not None else [])
        self.testset = Subset(self.torch_dataset, test_idxs if test_idxs is not None else [])
        # preprocessing
        self.scale = scale
        self.scaling_type = scaling_type
        self.scaling_axis = scaling_axis
        self.scale_exogenous = ensure_list(scale_exogenous) if scale_exogenous is not None else None
        # data loaders
        self.batch_size = batch_size
        self.workers = workers
        self.samples_per_epoch = samples_per_epoch

    @property
    def is_spatial(self):
        return isinstance(self.torch_dataset, SpatioTemporalDataset)

    @property
    def n_nodes(self):
        if not self.has_setup_fit:
            raise ValueError('You should initialize the datamodule first.')
        return self.torch_dataset.n_nodes if self.is_spatial else None

    @property
    def d_in(self):
        if not self.has_setup_fit:
            raise ValueError('You should initialize the datamodule first.')
        return self.torch_dataset.n_channels

    @property
    def d_out(self):
        if not self.has_setup_fit:
            raise ValueError('You should initialize the datamodule first.')
        return self.torch_dataset.horizon

    @property
    def train_slice(self):
        return self.torch_dataset.expand_indices(self.trainset.indices, merge=True)

    @property
    def val_slice(self):
        return self.torch_dataset.expand_indices(self.valset.indices, merge=True)

    @property
    def test_slice(self):
        return self.torch_dataset.expand_indices(self.testset.indices, merge=True)

    def get_scaling_axes(self, dim='global'):
        scaling_axis = tuple()
        if dim == 'global':
            scaling_axis = (0, 1, 2)
        elif dim == 'channels':
            scaling_axis = (0, 1)
        elif dim == 'nodes':
            scaling_axis = (0,)
        # Remove last dimension for temporal datasets
        if not self.is_spatial:
            scaling_axis = scaling_axis[:-1]

        if not len(scaling_axis):
            raise ValueError(f'Scaling axis "{dim}" not valid.')

        return scaling_axis

    def get_scaler(self):
        if self.scaling_type == 'std':
            return StandardScaler
        elif self.scaling_type == 'minmax':
            return MinMaxScaler
        else:
            raise NotImplementedError(f'Scaling type "{self.scaling_type}" not implemented.')

    def setup(self, stage=None):

        if self.scale:
            scaling_axis = self.get_scaling_axes(self.scaling_axis)
            train = self.torch_dataset.data.numpy()[self.train_slice]
            train_mask = self.torch_dataset.mask.numpy()[self.train_slice] if 'mask' in self.torch_dataset else None
            scaler = self.get_scaler()(scaling_axis).fit(train, mask=train_mask, keepdims=True).to_torch()
            self.torch_dataset.scaler = scaler

            if self.scale_exogenous is not None:
                for label in self.scale_exogenous:
                    exo = getattr(self.torch_dataset, label)
                    scaler = self.get_scaler()(scaling_axis)
                    scaler.fit(exo[self.train_slice], keepdims=True).to_torch()
                    setattr(self.torch_dataset, label, scaler.transform(exo))

    def _data_loader(self, dataset, shuffle=False, batch_size=None, **kwargs):
        batch_size = self.batch_size if batch_size is None else batch_size
        return DataLoader(dataset,
                          shuffle=shuffle,
                          batch_size=batch_size,
                          num_workers=self.workers,
                          **kwargs)

    def train_dataloader(self, shuffle=True, batch_size=None):
        if self.samples_per_epoch is not None:
            sampler = RandomSampler(self.trainset, replacement=True, num_samples=self.samples_per_epoch)
            return self._data_loader(self.trainset, False, batch_size, sampler=sampler, drop_last=True)
        return self._data_loader(self.trainset, shuffle, batch_size, drop_last=True)

    def val_dataloader(self, shuffle=False, batch_size=None):
        return self._data_loader(self.valset, shuffle, batch_size)

    def test_dataloader(self, shuffle=False, batch_size=None):
        return self._data_loader(self.testset, shuffle, batch_size)

    @staticmethod
    def add_argparse_args(parser, **kwargs):
        parser.add_argument('--batch-size', type=int, default=64)
        parser.add_argument('--scaling-axis', type=str, default="channels")
        parser.add_argument('--scaling-type', type=str, default="std")
        parser.add_argument('--scale', type=str_to_bool, nargs='?', const=True, default=True)
        parser.add_argument('--workers', type=int, default=0)
        parser.add_argument('--samples-per-epoch', type=int, default=None)
        return parser
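`get_scaling_axes` decides which axes the scaler statistics are reduced over. The NumPy sketch below copies that mapping into a standalone, hypothetical `scaling_axes` helper, assuming spatial data laid out as `[steps, nodes, channels]` (temporal data as `[steps, channels]`, where the last axis of each tuple is dropped), and prints the shape of the resulting mean with `keepdims=True`:

```python
import numpy as np

# Copy of the mapping in SpatioTemporalDataModule.get_scaling_axes
_AXES = {'global': (0, 1, 2), 'channels': (0, 1), 'nodes': (0,)}


def scaling_axes(dim: str = 'global', is_spatial: bool = True) -> tuple:
    axes = _AXES.get(dim, tuple())
    if not is_spatial:
        # Temporal datasets have no node dimension, so drop the last axis
        axes = axes[:-1]
    if not len(axes):
        raise ValueError(f'Scaling axis "{dim}" not valid.')
    return axes


x = np.random.default_rng(0).normal(size=(100, 5, 2))  # 100 steps, 5 nodes, 2 channels
for dim in _AXES:
    mean = x.mean(axis=scaling_axes(dim), keepdims=True)
    print(dim, mean.shape)
# global (1, 1, 1)   -> one statistic for everything
# channels (1, 1, 2) -> one statistic per channel, shared across nodes
# nodes (1, 5, 2)    -> one statistic per node and channel
```

With this mapping, `'nodes'` is not available for purely temporal datasets: `(0,)[:-1]` is empty, so a `ValueError` is raised.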


================================================
FILE: lib/data/imputation_dataset.py
================================================
import numpy as np
import torch

from . import TemporalDataset, SpatioTemporalDataset


class ImputationDataset(TemporalDataset):

    def __init__(self, data,
                 index=None,
                 mask=None,
                 eval_mask=None,
                 freq=None,
                 trend=None,
                 scaler=None,
                 window=24,
                 stride=1,
                 exogenous=None):
        if mask is None:
            mask = np.ones_like(data)
        if exogenous is None:
            exogenous = dict()
        exogenous['mask_window'] = mask
        if eval_mask is not None:
            exogenous['eval_mask_window'] = eval_mask
        super(ImputationDataset, self).__init__(data,
                                                index=index,
                                                exogenous=exogenous,
                                                trend=trend,
                                                scaler=scaler,
                                                freq=freq,
                                                window=window,
                                                horizon=window,
                                                delay=-window,
                                                stride=stride)

    def get(self, item, preprocess=False):
        res, transform = super(ImputationDataset, self).get(item, preprocess)
        # torch.where requires a boolean condition; masks may be stored as uint8
        res['x'] = torch.where(res['mask'].bool(), res['x'], torch.zeros_like(res['x']))
        return res, transform


class GraphImputationDataset(ImputationDataset, SpatioTemporalDataset):
    pass
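`ImputationDataset` sets `horizon=window` and `delay=-window`, so the horizon offset `window + delay` is zero and the target `y` covers exactly the same time steps as the input `x`; `get` then zeroes the entries of `x` whose mask is 0. A NumPy-only sketch of that indexing on a toy series (the values and the missing positions are made up for illustration):

```python
import numpy as np

window = 4
delay = -window                  # as set by ImputationDataset
horizon = window
horizon_offset = window + delay  # 0: target window coincides with the input window

data = np.arange(10.0).reshape(10, 1)  # [steps, features]
mask = np.ones_like(data, dtype=bool)
mask[[1, 6]] = False                   # pretend these two readings are missing

idx = 0
x = data[idx:idx + window]
y = data[idx + horizon_offset:idx + horizon_offset + horizon]
m = mask[idx:idx + window]

# Same masking rule as ImputationDataset.get: keep observed values, zero the rest
x_in = np.where(m, x, np.zeros_like(x))
print(x_in.ravel().tolist())  # [0.0, 0.0, 2.0, 3.0]
```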


================================================
FILE: lib/data/preprocessing/__init__.py
================================================
from .scalers import *


================================================
FILE: lib/data/preprocessing/scalers.py
================================================
from abc import ABC, abstractmethod
import numpy as np


class AbstractScaler(ABC):

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __repr__(self):
        params = ", ".join([f"{k}={str(v)}" for k, v in self.params().items()])
        return "{}({})".format(self.__class__.__name__, params)

    def __call__(self, *args, **kwargs):
        return self.transform(*args, **kwargs)

    def params(self):
        return {k: v for k, v in self.__dict__.items() if not callable(v) and not k.startswith("__")}

    @abstractmethod
    def fit(self, x):
        pass

    @abstractmethod
    def transform(self, x):
        pass

    @abstractmethod
    def inverse_transform(self, x):
        pass

    def fit_transform(self, x):
        self.fit(x)
        return self.transform(x)

    def to_torch(self):
        import torch
        for p in self.params():
            param = getattr(self, p)
            param = np.atleast_1d(param)
            param = torch.tensor(param).float()
            setattr(self, p, param)
        return self


class Scaler(AbstractScaler):
    def __init__(self, offset=0., scale=1.):
        self.bias = offset
        self.scale = scale
        super(Scaler, self).__init__()

    def params(self):
        return dict(bias=self.bias, scale=self.scale)

    def fit(self, x, mask=None, keepdims=True):
        # Fixed offset/scale: nothing to estimate; return self to allow chaining like subclasses
        return self

    def transform(self, x):
        return (x - self.bias) / self.scale

    def inverse_transform(self, x):
        return x * self.scale + self.bias

    def fit_transform(self, x, mask=None, keepdims=True):
        self.fit(x, mask, keepdims)
        return self.transform(x)


class StandardScaler(Scaler):
    def __init__(self, axis=0):
        self.axis = axis
        super(StandardScaler, self).__init__()

    def fit(self, x, mask=None, keepdims=True):
        if mask is not None:
            x = np.where(mask, x, np.nan)
            self.bias = np.nanmean(x, axis=self.axis, keepdims=keepdims)
            self.scale = np.nanstd(x, axis=self.axis, keepdims=keepdims)
        else:
            self.bias = x.mean(axis=self.axis, keepdims=keepdims)
            self.scale = x.std(axis=self.axis, keepdims=keepdims)
        return self


class MinMaxScaler(Scaler):
    def __init__(self, axis=0):
        self.axis = axis
        super(MinMaxScaler, self).__init__()

    def fit(self, x, mask=None, keepdims=True):
        if mask is not None:
            x = np.where(mask, x, np.nan)
            self.bias = np.nanmin(x, axis=self.axis, keepdims=keepdims)
            self.scale = (np.nanmax(x, axis=self.axis, keepdims=keepdims) - self.bias)
        else:
            self.bias = x.min(axis=self.axis, keepdims=keepdims)
            self.scale = (x.max(axis=self.axis, keepdims=keepdims) - self.bias)
        return self
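When a mask is given, `StandardScaler.fit` hides invalid entries as NaN and uses NaN-aware statistics, so missing or corrupted readings do not affect the mean and standard deviation. A standalone sketch of the same recipe; `fit_masked_standard` is a hypothetical helper and the toy values are made up:

```python
import numpy as np


def fit_masked_standard(x, mask, axis=0):
    # Same recipe as StandardScaler.fit with a mask: invalid entries become NaN,
    # then NaN-aware statistics ignore them.
    xm = np.where(mask, x, np.nan)
    bias = np.nanmean(xm, axis=axis, keepdims=True)
    scale = np.nanstd(xm, axis=axis, keepdims=True)
    return bias, scale


x = np.array([[1.0, 10.0],
              [3.0, 1000.0],   # 1000.0 is an invalid reading
              [5.0, 30.0]])
mask = np.array([[1, 1],
                 [1, 0],
                 [1, 1]], dtype=bool)

bias, scale = fit_masked_standard(x, mask, axis=0)
# bias == [[3., 20.]]: the masked 1000.0 is ignored
z = (x - bias) / scale      # as in Scaler.transform
x_back = z * scale + bias   # as in Scaler.inverse_transform
```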


================================================
FILE: lib/data/spatiotemporal_dataset.py
================================================
import numpy as np
import pandas as pd
from einops import rearrange

from .temporal_dataset import TemporalDataset


class SpatioTemporalDataset(TemporalDataset):
    def __init__(self, data,
                 index=None,
                 trend=None,
                 scaler=None,
                 freq=None,
                 window=24,
                 horizon=24,
                 delay=0,
                 stride=1,
                 **exogenous):
        """
        PyTorch dataset for spatiotemporal data, i.e., a (multivariate) time series observed at multiple nodes.

        :param data:
            raw target time series (ts) (can be multivariate), shape: [steps, nodes, (features)]
        :param index:
            temporal indices for the data
        :param trend:
            trend time series to be removed from the ts, shape: [steps, nodes, (features)]
        :param scaler:
            scaler used to normalize the ts (after de-trending)
        :param freq:
            frequency of the temporal indices
        :param window:
            length of windows returned by __getitem__
        :param horizon:
            length of prediction horizon returned by __getitem__
        :param delay:
            delay between input and prediction
        :param stride:
            step between the starting points of consecutive windows
        :param exogenous:
            additional keyword arguments forwarded to TemporalDataset (e.g., exogenous=dict(mask_window=mask));
            exogenous variables have shape [steps, nodes, (features)]
        """
        super(SpatioTemporalDataset, self).__init__(data,
                                                    index=index,
                                                    trend=trend,
                                                    scaler=scaler,
                                                    freq=freq,
                                                    window=window,
                                                    horizon=horizon,
                                                    delay=delay,
                                                    stride=stride,
                                                    **exogenous)

    def __repr__(self):
        return "{}(n_samples={}, n_nodes={})".format(self.__class__.__name__, len(self), self.n_nodes)

    @property
    def n_nodes(self):
        return self.data.shape[1]

    @staticmethod
    def check_dim(data):
        if data.ndim == 2:  # [steps, nodes] -> [steps, nodes, features]
            data = rearrange(data, 's (n f) -> s n f', f=1)
        elif data.ndim == 1:
            data = rearrange(data, '(s n f) -> s n f', n=1, f=1)
        elif data.ndim == 3:
            pass
        else:
            raise ValueError(f'Invalid data dimensions {data.shape}')
        return data

    def dataframe(self):
        if self.n_channels == 1:
            return pd.DataFrame(data=np.squeeze(self.data, -1), index=self.index)
        raise NotImplementedError()
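`SpatioTemporalDataset.check_dim` normalizes every input to the `[steps, nodes, features]` layout using `einops.rearrange`. The sketch below is a NumPy-only equivalent written for illustration (it is not the repository's function), which makes the resulting shapes explicit:

```python
import numpy as np


def check_dim(data: np.ndarray) -> np.ndarray:
    # NumPy equivalent of SpatioTemporalDataset.check_dim
    if data.ndim == 2:    # [steps, nodes] -> [steps, nodes, 1]
        return data[..., np.newaxis]
    if data.ndim == 1:    # [steps] -> [steps, 1, 1]
        return data.reshape(-1, 1, 1)
    if data.ndim == 3:    # already [steps, nodes, features]
        return data
    raise ValueError(f'Invalid data dimensions {data.shape}')


print(check_dim(np.zeros((12, 5))).shape)  # (12, 5, 1)
print(check_dim(np.zeros(12)).shape)       # (12, 1, 1)
```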


================================================
FILE: lib/data/temporal_dataset.py
================================================
import numpy as np
import pandas as pd
import torch
from einops import rearrange
from pandas import DatetimeIndex
from torch.utils.data import Dataset

from .preprocessing import AbstractScaler


class TemporalDataset(Dataset):
    def __init__(self, data,
                 index=None,
                 freq=None,
                 exogenous=None,
                 trend=None,
                 scaler=None,
                 window=24,
                 horizon=24,
                 delay=0,
                 stride=1):
        """Wrapper class for datasets whose entries depend on a sequence of temporal indices.

        Parameters
        ----------
        data : np.ndarray
            Data relative to the main signal.
        index : DatetimeIndex or None
            Temporal indices for the data.
        freq : pd.DateTimeIndex.freq or str or None
            Frequency of the indices (defaults to index.freq, or to the frequency inferred from index).
        exogenous : dict or None
            Exogenous data and labels paired with the main signal (default is None).
        trend : np.ndarray or None
            Trend paired with the main signal (default is None). Must have the same length as 'data'.
        scaler : AbstractScaler or None
            Scaler to be applied to the data (default is None).
        window : int
            Size of the sliding window in the past.
        horizon : int
            Size of the prediction horizon.
        delay : int
            Offset between the end of the window and the start of the horizon.
        stride : int
            Step between the starting indices of consecutive samples.

        Raises
        ------
        ValueError
            If 'data', 'trend' or an exogenous variable has an invalid number of dimensions.
        """
        super(TemporalDataset, self).__init__()
        # Initialize signatures
        self.__exogenous_keys = dict()
        self.__reserved_signature = {'data', 'trend', 'x', 'y'}
        # Store data
        self.data = data
        if exogenous is not None:
            for name, value in exogenous.items():
                self.add_exogenous(value, name, for_window=True, for_horizon=True)
        # Store time information
        self.index = index
        try:
            freq = freq or index.freq or index.inferred_freq
            self.freq = pd.tseries.frequencies.to_offset(freq)
        except AttributeError:
            self.freq = None
        # Store offset information
        self.window = window
        self.delay = delay
        self.horizon = horizon
        self.stride = stride
        # Identify the indices of the samples
        self._indices = np.arange(self.data.shape[0] - self.sample_span + 1)[::self.stride]
        # Store preprocessing options
        self.trend = trend
        self.scaler = scaler

    def __getitem__(self, item):
        return self.get(item, self.preprocess)

    def __contains__(self, item):
        return item in self.__exogenous_keys

    def __len__(self):
        return len(self._indices)

    def __repr__(self):
        return "{}(n_samples={})".format(self.__class__.__name__, len(self))

    # Getter and setter for data

    @property
    def data(self):
        return self.__data

    @data.setter
    def data(self, value):
        assert value is not None
        self.__data = self.check_input(value)

    @property
    def trend(self):
        return self.__trend

    @trend.setter
    def trend(self, value):
        self.__trend = self.check_input(value)

    # Setter for exogenous data

    def add_exogenous(self, obj, name, for_window=True, for_horizon=False):
        assert isinstance(name, str)
        if name.endswith('_window'):
            name = name[:-7]
            for_window, for_horizon = True, False
        if name.endswith('_horizon'):
            name = name[:-8]
            for_window, for_horizon = False, True
        if name in self.__reserved_signature:
            raise ValueError("Channel '{0}' cannot be added in this way. Use obj.{0} instead.".format(name))
        if not (for_window or for_horizon):
            raise ValueError("Either for_window or for_horizon must be True.")
        obj = self.check_input(obj)
        setattr(self, name, obj)
        self.__exogenous_keys[name] = dict(for_window=for_window, for_horizon=for_horizon)
        return self

    # Dataset properties

    @property
    def horizon_offset(self):
        return self.window + self.delay

    @property
    def sample_span(self):
        return max(self.horizon_offset + self.horizon, self.window)

    @property
    def preprocess(self):
        return (self.trend is not None) or (self.scaler is not None)

    @property
    def n_steps(self):
        return self.data.shape[0]

    @property
    def n_channels(self):
        return self.data.shape[-1]

    @property
    def indices(self):
        return self._indices

    # Signature information

    @property
    def exo_window_keys(self):
        return {k for k, v in self.__exogenous_keys.items() if v['for_window']}

    @property
    def exo_horizon_keys(self):
        return {k for k, v in self.__exogenous_keys.items() if v['for_horizon']}

    @property
    def exo_common_keys(self):
        return self.exo_window_keys.intersection(self.exo_horizon_keys)

    @property
    def signature(self):
        attrs = []
        if self.window > 0:
            attrs.append('x')
            for attr in self.exo_window_keys:
                attrs.append(attr if attr not in self.exo_common_keys else (attr + '_window'))
        for attr in self.exo_horizon_keys:
            attrs.append(attr if attr not in self.exo_common_keys else (attr + '_horizon'))
        attrs.append('y')
        attrs = tuple(attrs)
        preprocess = []
        if self.trend is not None:
            preprocess.append('trend')
        if self.scaler is not None:
            preprocess.extend(self.scaler.params())
        preprocess = tuple(preprocess)
        return dict(data=attrs, preprocessing=preprocess)

    # Item getters

    def get(self, item, preprocess=False):
        idx = self._indices[item]
        res, transform = dict(), dict()
        if self.window > 0:
            res['x'] = self.data[idx:idx + self.window]
            for attr in self.exo_window_keys:
                key = attr if attr not in self.exo_common_keys else (attr + '_window')
                res[key] = getattr(self, attr)[idx:idx + self.window]
        for attr in self.exo_horizon_keys:
            key = attr if attr not in self.exo_common_keys else (attr + '_horizon')
            res[key] = getattr(self, attr)[idx + self.horizon_offset:idx + self.horizon_offset + self.horizon]
        res['y'] = self.data[idx + self.horizon_offset:idx + self.horizon_offset + self.horizon]
        if preprocess:
            if self.trend is not None:
                y_trend = self.trend[idx + self.horizon_offset:idx + self.horizon_offset + self.horizon]
                res['y'] = res['y'] - y_trend
                transform['trend'] = y_trend
                if 'x' in res:
                    res['x'] = res['x'] - self.trend[idx:idx + self.window]
            if self.scaler is not None:
                transform.update(self.scaler.params())
                if 'x' in res:
                    res['x'] = self.scaler.transform(res['x'])
        return res, transform

    def snapshot(self, indices=None, preprocess=True):
        if not self.preprocess:
            preprocess = False
        data, prep = [{k: [] for k in sign} for sign in self.signature.values()]
        indices = np.arange(len(self._indices)) if indices is None else indices
        for idx in indices:
            data_i, prep_i = self.get(idx, preprocess)
            [v.append(data_i[k]) for k, v in data.items()]
            if len(prep_i):
                [v.append(prep_i[k]) for k, v in prep.items()]
        data = {k: np.stack(ds) for k, ds in data.items() if len(ds)}
        if len(prep):
            prep = {k: np.stack(ds) if k == 'trend' else ds[0] for k, ds in prep.items() if len(ds)}
        return data, prep

    # Data utilities

    def expand_indices(self, indices=None, unique=False, merge=False):
        ds_indices = dict.fromkeys([time for time in ['window', 'horizon'] if getattr(self, time) > 0])
        indices = np.arange(len(self._indices)) if indices is None else indices
        if 'window' in ds_indices:
            w_idxs = [np.arange(idx, idx + self.window) for idx in self._indices[indices]]
            ds_indices['window'] = np.concatenate(w_idxs)
        if 'horizon' in ds_indices:
            h_idxs = [np.arange(idx + self.horizon_offset, idx + self.horizon_offset + self.horizon)
                      for idx in self._indices[indices]]
            ds_indices['horizon'] = np.concatenate(h_idxs)
        if unique:
            ds_indices = {k: np.unique(v) for k, v in ds_indices.items()}
        if merge:
            ds_indices = np.unique(np.concatenate(list(ds_indices.values())))
        return ds_indices

    def overlapping_indices(self, idxs1, idxs2, synch_mode='window', as_mask=False):
        assert synch_mode in ['window', 'horizon']
        ts1 = self.data_timestamps(idxs1, flatten=False)[synch_mode]
        ts2 = self.data_timestamps(idxs2, flatten=False)[synch_mode]
        common_ts = np.intersect1d(np.unique(ts1), np.unique(ts2))
        is_overlapping = lambda sample: np.any(np.in1d(sample, common_ts))
        m1 = np.apply_along_axis(is_overlapping, 1, ts1)
        m2 = np.apply_along_axis(is_overlapping, 1, ts2)
        if as_mask:
            return m1, m2
        return np.sort(idxs1[m1]), np.sort(idxs2[m2])

    def data_timestamps(self, indices=None, flatten=True):
        ds_indices = self.expand_indices(indices, unique=False)
        ds_timestamps = {k: self.index[v] for k, v in ds_indices.items()}
        if not flatten:
            ds_timestamps = {k: np.array(v).reshape(-1, getattr(self, k)) for k, v in ds_timestamps.items()}
        return ds_timestamps

    def reduce_dataset(self, indices, inplace=False):
        if not inplace:
            from copy import deepcopy
            dataset = deepcopy(self)
        else:
            dataset = self
        old_index = dataset.index[dataset._indices[indices]]
        ds_indices = dataset.expand_indices(indices, merge=True)
        dataset.index = dataset.index[ds_indices]
        dataset.data = dataset.data[ds_indices]
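        if getattr(dataset, 'mask', None) is not None and 'mask' not in dataset:
            # masks registered as exogenous variables are sliced in the loop below
            dataset.mask = dataset.mask[ds_indices]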
        if dataset.trend is not None:
            dataset.trend = dataset.trend[ds_indices]
        for attr in dataset.exo_window_keys.union(dataset.exo_horizon_keys):
            if getattr(dataset, attr, None) is not None:
                setattr(dataset, attr, getattr(dataset, attr)[ds_indices])
        dataset._indices = np.flatnonzero(np.in1d(dataset.index, old_index))
        return dataset

    def check_input(self, data):
        if data is None:
            return data
        data = self.check_dim(data)
        data = data.clone().detach() if isinstance(data, torch.Tensor) else torch.tensor(data)
        # cast data
        if torch.is_floating_point(data):
            return data.float()
        elif data.dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64]:
            return data.int()
        return data

    # Class-specific methods (override in children)

    @staticmethod
    def check_dim(data):
        if data.ndim == 1:  # [steps] -> [steps, features]
            data = rearrange(data, '(s f) -> s f', f=1)
        elif data.ndim != 2:
            raise ValueError(f'Invalid data dimensions {data.shape}')
        return data

    def dataframe(self):
        return pd.DataFrame(data=self.data, index=self.index)

    @staticmethod
    def add_argparse_args(parser, **kwargs):
        parser.add_argument('--window', type=int, default=24)
        parser.add_argument('--horizon', type=int, default=24)
        parser.add_argument('--delay', type=int, default=0)
        parser.add_argument('--stride', type=int, default=1)
        return parser
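The sample layout of `TemporalDataset` follows from three quantities: `horizon_offset = window + delay`, `sample_span = max(horizon_offset + horizon, window)`, and the start indices `arange(n_steps - sample_span + 1)[::stride]`. `expand_indices(..., merge=True)` then returns the union of all time steps touched by a set of samples; this is what `SpatioTemporalDataModule.train_slice` uses to select the rows for fitting the scaler. A standalone sketch with hypothetical helper names (note that `covered_steps` takes start values directly, whereas `expand_indices` takes positions into `_indices`):

```python
import numpy as np


def sample_starts(n_steps, window, horizon, delay=0, stride=1):
    # Same arithmetic as TemporalDataset.horizon_offset, .sample_span and ._indices
    horizon_offset = window + delay
    sample_span = max(horizon_offset + horizon, window)
    starts = np.arange(n_steps - sample_span + 1)[::stride]
    return starts, horizon_offset


def covered_steps(starts, window, horizon, horizon_offset):
    # Like expand_indices(..., merge=True): union of the steps used by the samples
    parts = []
    if window > 0:
        parts += [np.arange(s, s + window) for s in starts]
    if horizon > 0:
        parts += [np.arange(s + horizon_offset, s + horizon_offset + horizon) for s in starts]
    return np.unique(np.concatenate(parts))


# Forecasting-style layout: 3 input steps followed by 2 target steps
starts, offset = sample_starts(n_steps=10, window=3, horizon=2)
print(starts.tolist())                                   # [0, 1, 2, 3, 4, 5]
print(covered_steps(starts[:2], 3, 2, offset).tolist())  # [0, 1, 2, 3, 4, 5]

# Imputation layout (delay=-window): input and target share the same steps
starts_imp, offset_imp = sample_starts(n_steps=10, window=4, horizon=4, delay=-4)
print(len(starts_imp), offset_imp)                       # 7 0
```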


================================================
FILE: lib/datasets/__init__.py
================================================
from .air_quality import AirQuality
from .metr_la import MissingValuesMetrLA
from .pems_bay import MissingValuesPemsBay
from .synthetic import ChargedParticles


================================================
FILE: lib/datasets/air_quality.py
================================================
import os

import numpy as np
import pandas as pd

from lib import datasets_path
from .pd_dataset import PandasDataset
from ..utils.utils import disjoint_months, infer_mask, compute_mean, geographical_distance, thresholded_gaussian_kernel


class AirQuality(PandasDataset):
    SEED = 3210

    def __init__(self, impute_nans=False, small=False, freq='60T', masked_sensors=None):
        self.random = np.random.default_rng(self.SEED)
        self.test_months = [3, 6, 9, 12]
        self.infer_eval_from = 'next'
        self.eval_mask = None
        df, dist, mask = self.load(impute_nans=impute_nans, small=small, masked_sensors=masked_sensors)
        self.dist = dist
        if masked_sensors is None:
            self.masked_sensors = list()
        else:
            self.masked_sensors = list(masked_sensors)
        super().__init__(dataframe=df, u=None, mask=mask, name='air', freq=freq, aggr='nearest')

    def load_raw(self, small=False):
        if small:
            path = os.path.join(datasets_path['air'], 'small36.h5')
            eval_mask = pd.DataFrame(pd.read_hdf(path, 'eval_mask'))
        else:
            path = os.path.join(datasets_path['air'], 'full437.h5')
            eval_mask = None
        df = pd.DataFrame(pd.read_hdf(path, 'pm25'))
        stations = pd.DataFrame(pd.read_hdf(path, 'stations'))
        return df, stations, eval_mask

    def load(self, impute_nans=True, small=False, masked_sensors=None):
        # load readings and stations metadata
        df, stations, eval_mask = self.load_raw(small)
        # compute the masks
        mask = (~np.isnan(df.values)).astype('uint8')  # 1 if value is not nan else 0
        if eval_mask is None:
            eval_mask = infer_mask(df, infer_from=self.infer_eval_from)

        eval_mask = eval_mask.values.astype('uint8')
        if masked_sensors is not None:
            eval_mask[:, masked_sensors] = np.where(mask[:, masked_sensors], 1, 0)
        self.eval_mask = eval_mask  # 1 if value is ground-truth for imputation else 0
        # optionally replace nans with weekly mean by hour
        if impute_nans:
            df = df.fillna(compute_mean(df))
        # compute distances from latitude and longitude degrees
        st_coord = stations.loc[:, ['latitude', 'longitude']]
        dist = geographical_distance(st_coord, to_rad=True).values
        return df, dist, mask

    def splitter(self, dataset, val_len=1., in_sample=False, window=0):
        nontest_idxs, test_idxs = disjoint_months(dataset, months=self.test_months, synch_mode='horizon')
        if in_sample:
            train_idxs = np.arange(len(dataset))
            val_months = [(m - 1) % 12 for m in self.test_months]
            _, val_idxs = disjoint_months(dataset, months=val_months, synch_mode='horizon')
        else:
            # take equal number of samples before each month of testing
            val_len = (int(val_len * len(nontest_idxs)) if val_len < 1 else val_len) // len(self.test_months)
            # get indices of first day of each testing month
            delta_idxs = np.diff(test_idxs)
            end_month_idxs = test_idxs[1:][np.flatnonzero(delta_idxs > delta_idxs.min())]
            if len(end_month_idxs) < len(self.test_months):
                end_month_idxs = np.insert(end_month_idxs, 0, test_idxs[0])
            # expand month indices
            month_val_idxs = [np.arange(v_idx - val_len, v_idx) - window for v_idx in end_month_idxs]
            val_idxs = np.concatenate(month_val_idxs) % len(dataset)
            # remove overlapping indices from training set
            ovl_idxs, _ = dataset.overlapping_indices(nontest_idxs, val_idxs, synch_mode='horizon', as_mask=True)
            train_idxs = nontest_idxs[~ovl_idxs]
        return [train_idxs, val_idxs, test_idxs]

    def get_similarity(self, thr=0.1, include_self=False, force_symmetric=False, sparse=False, **kwargs):
        theta = np.std(self.dist[:36, :36])  # use same theta for both air and air36
        adj = thresholded_gaussian_kernel(self.dist, theta=theta, threshold=thr)
        if not include_self:
            adj[np.diag_indices_from(adj)] = 0.
        if force_symmetric:
            adj = np.maximum.reduce([adj, adj.T])
        if sparse:
            import scipy.sparse as sps
            adj = sps.coo_matrix(adj)
        return adj

    @property
    def mask(self):
        return self._mask

    @property
    def training_mask(self):
        return self._mask if self.eval_mask is None else (self._mask & (1 - self.eval_mask))

    def test_interval_mask(self, dtype=bool, squeeze=True):
        m = np.isin(self.df.index.month, self.test_months).astype(dtype)
        if squeeze:
            return m
        return m[:, None]

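# Illustrative sketch (hypothetical helpers, not part of the original library and
# not used by it): the in-sample branch of `splitter` validates on the calendar
# month preceding each test month, and `test_interval_mask` selects steps by
# `index.month`. The helpers below show both ideas on a plain DatetimeIndex,
# assuming months are numbered 1-12 as in pandas. Unlike `disjoint_months`, they
# ignore window overlap and synchronization entirely.
import numpy as np
import pandas as pd


def _sketch_previous_month(month):
    # map 1 -> 12, 2 -> 1, ..., 12 -> 11, keeping the result in 1-12
    return (month - 2) % 12 + 1


def _sketch_month_split(index: pd.DatetimeIndex, test_months):
    # positional indices of the steps outside / inside the test months
    in_test = np.isin(index.month, test_months)
    return np.flatnonzero(~in_test), np.flatnonzero(in_test)


# Example usage:
#   idx = pd.date_range('2020-01-01', '2020-12-31', freq='D')
#   train, test = _sketch_month_split(idx, [3])  # test covers the 31 days of March
#   _sketch_previous_month(1)                    # 12
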

================================================
FILE: lib/datasets/metr_la.py
================================================
import os

import numpy as np
import pandas as pd

from lib import datasets_path
from .pd_dataset import PandasDataset
from ..utils import sample_mask


class MetrLA(PandasDataset):
    def __init__(self, impute_zeros=False, freq='5T'):

        df, dist, mask = self.load(impute_zeros=impute_zeros)
        self.dist = dist
        super().__init__(dataframe=df, u=None, mask=mask, name='la', freq=freq, aggr='nearest')

    def load(self, impute_zeros=True):
        path = os.path.join(datasets_path['la'], 'metr_la.h5')
        df = pd.read_hdf(path)
        datetime_idx = sorted(df.index)
        date_range = pd.date_range(datetime_idx[0], datetime_idx[-1], freq='5T')
        df = df.reindex(index=date_range)
        mask = ~np.isnan(df.values)
        if impute_zeros:
            mask = mask * (df.values != 0.).astype('uint8')
            df = df.replace(to_replace=0., method='ffill')
        else:
            mask = None
        dist = self.load_distance_matrix()
        return df, dist, mask

    def load_distance_matrix(self):
        path = os.path.join(datasets_path['la'], 'metr_la_dist.npy')
        try:
            dist = np.load(path)
        except FileNotFoundError:
            # no cached matrix yet: build it from the raw CSV and cache it
            distances = pd.read_csv(os.path.join(datasets_path['la'], 'distances_la.csv'))
            with open(os.path.join(datasets_path['la'], 'sensor_ids_la.txt')) as f:
                ids = f.read().strip().split(',')
            num_sensors = len(ids)
            dist = np.ones((num_sensors, num_sensors), dtype=np.float32) * np.inf
            # Builds sensor id to index map.
            sensor_id_to_ind = {int(sensor_id): i for i, sensor_id in enumerate(ids)}

            # Fills cells in the matrix with distances.
            for row in distances.values:
                if row[0] not in sensor_id_to_ind or row[1] not in sensor_id_to_ind:
                    continue
                dist[sensor_id_to_ind[row[0]], sensor_id_to_ind[row[1]]] = row[2]
            np.save(path, dist)
        return dist

    def get_similarity(self, thr=0.1, force_symmetric=False, sparse=False):
        finite_dist = self.dist.reshape(-1)
        finite_dist = finite_dist[~np.isinf(finite_dist)]
        sigma = finite_dist.std()
        adj = np.exp(-np.square(self.dist / sigma))
        adj[adj < thr] = 0.
        if force_symmetric:
            adj = np.maximum.reduce([adj, adj.T])
        if sparse:
            import scipy.sparse as sps
            adj = sps.coo_matrix(adj)
        return adj

    @property
    def mask(self):
        return self._mask


class MissingValuesMetrLA(MetrLA):
    SEED = 9101112

    def __init__(self, p_fault=0.0015, p_noise=0.05):
        super(MissingValuesMetrLA, self).__init__(impute_zeros=True)
        self.rng = np.random.default_rng(self.SEED)
        self.p_fault = p_fault
        self.p_noise = p_noise
        eval_mask = sample_mask(self.numpy().shape,
                                p=p_fault,
                                p_noise=p_noise,
                                min_seq=12,
                                max_seq=12 * 4,
                                rng=self.rng)
        self.eval_mask = (eval_mask & self.mask).astype('uint8')

    @property
    def training_mask(self):
        return self.mask if self.eval_mask is None else (self.mask & (1 - self.eval_mask))

    def splitter(self, dataset, val_len=0, test_len=0, window=0):
        idx = np.arange(len(dataset))
        if test_len < 1:
            test_len = int(test_len * len(idx))
        if val_len < 1:
            val_len = int(val_len * (len(idx) - test_len))
        test_start = len(idx) - test_len
        val_start = test_start - val_len
        return [idx[:val_start - window], idx[val_start:test_start - window], idx[test_start:]]

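# Illustrative sketch (hypothetical helper, not part of the original library and
# not used by it): `get_similarity` above builds a thresholded Gaussian kernel
# w_ij = exp(-(d_ij / sigma)^2), with sigma the standard deviation of the finite
# distances; pairs without a recorded distance (inf) get weight exp(-inf) = 0.
# This standalone version works on any square distance matrix.
import numpy as np


def _sketch_gaussian_adjacency(dist, thr=0.1):
    finite = dist[np.isfinite(dist)]        # ignore missing (inf) distances
    sigma = finite.std()
    adj = np.exp(-np.square(dist / sigma))  # inf distances map to 0
    adj[adj < thr] = 0.                     # sparsify weak connections
    return adj
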
================================================
FILE: lib/datasets/pd_dataset.py
================================================
import numpy as np
import pandas as pd
import torch


class PandasDataset:
    def __init__(self, dataframe: pd.DataFrame, u: pd.DataFrame = None, name='pd-dataset', mask=None, freq=None,
                 aggr='sum', **kwargs):
        """
        Initialize a tsl dataset from a pandas dataframe.


        :param dataframe: dataframe containing the data, shape: n_steps, n_nodes
        :param u: dataframe with exog variables
        :param name: optional name of the dataset
        :param mask: mask for valid data (1:valid, 0:not valid)
        :param freq: force a frequency (possibly by resampling)
        :param aggr: aggregation method after resampling
        """
        super().__init__()
        self.name = name

        # set dataset dataframe
        self.df = dataframe

        # set the optional dataframe of exogenous variables, keeping only the part
        # that overlaps with the data (assumes u.index is a subset of df.index)
        idx = sorted(self.df.index)
        self.start = idx[0]
        self.end = idx[-1]

        if u is not None:
            self.u = u[self.start:self.end]
        else:
            self.u = None

        if mask is not None:
            mask = np.asarray(mask).astype('uint8')
        self._mask = mask

        if freq is not None:
            self.resample_(freq=freq, aggr=aggr)
        else:
            self.freq = self.df.index.inferred_freq
            # make sure that all the dataframes are aligned
            self.resample_(self.freq, aggr=aggr)

        assert 'T' in self.freq
        self.samples_per_day = int(60 / int(self.freq[:-1]) * 24)

    def __repr__(self):
        return "{}(nodes={}, length={})".format(self.__class__.__name__, self.n_nodes, self.length)

    @property
    def has_mask(self):
        return self._mask is not None

    @property
    def has_u(self):
        return self.u is not None

    def resample_(self, freq, aggr):
        resampler = self.df.resample(freq)
        idx = self.df.index
        if aggr == 'sum':
            self.df = resampler.sum()
        elif aggr == 'mean':
            self.df = resampler.mean()
        elif aggr == 'nearest':
            self.df = resampler.nearest()
        else:
            raise ValueError(f'{aggr} is not a valid aggregation method.')

        if self.has_mask:
            resampler = pd.DataFrame(self._mask, index=idx).resample(freq)
            self._mask = resampler.min().to_numpy()

        if self.has_u:
            resampler = self.u.resample(freq)
            self.u = resampler.nearest()
        self.freq = freq

    def dataframe(self) -> pd.DataFrame:
        return self.df.copy()

    @property
    def length(self):
        return self.df.values.shape[0]

    @property
    def n_nodes(self):
        return self.df.values.shape[1]

    @property
    def mask(self):
        if self._mask is None:
            return np.ones_like(self.df.values).astype('uint8')
        return self._mask

    def numpy(self, return_idx=False):
        if return_idx:
            return self.numpy(), self.df.index
        return self.df.values

    def pytorch(self):
        data = self.numpy()
        return torch.FloatTensor(data)

    def __len__(self):
        return self.length

    @staticmethod
    def build():
        raise NotImplementedError

    def load_raw(self):
        raise NotImplementedError

    def load(self):
        raise NotImplementedError

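# Illustrative sketch (hypothetical helpers, not part of the original library and
# not used by it): `resample_` aggregates the mask with `min`, so a resampled step
# is marked valid only if every original step it covers was valid, and
# `samples_per_day` is the number of steps of length `freq` in one day. This
# version assumes a DatetimeIndex and pandas offset aliases such as '5min'.
import numpy as np
import pandas as pd


def _sketch_resample_with_mask(df, mask, freq, aggr='mean'):
    resampler = df.resample(freq)
    out = getattr(resampler, aggr)()  # e.g. 'sum', 'mean' or 'nearest'
    new_mask = pd.DataFrame(np.asarray(mask), index=df.index).resample(freq).min()
    return out, new_mask.to_numpy().astype('uint8')


def _sketch_samples_per_day(freq):
    # number of steps of length `freq` in 24 hours
    return pd.Timedelta(days=1) // pd.Timedelta(freq)
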

================================================
FILE: lib/datasets/pems_bay.py
================================================
import os

import numpy as np
import pandas as pd

from lib import datasets_path
from .pd_dataset import PandasDataset
from ..utils import sample_mask


class PemsBay(PandasDataset):
    def __init__(self):
        df, dist, mask = self.load()
        self.dist = dist
        super().__init__(dataframe=df, u=None, mask=mask, name='bay', freq='5T', aggr='nearest')

    def load(self, impute_zeros=True):
        path = os.path.join(datasets_path['bay'], 'pems_bay.h5')
        df = pd.read_hdf(path)
        datetime_idx = sorted(df.index)
        date_range = pd.date_range(datetime_idx[0], datetime_idx[-1], freq='5T')
        df = df.reindex(index=date_range)
        mask = ~np.isnan(df.values)
        df.ffill(axis=0, inplace=True)
        dist = self.load_distance_matrix(list(df.columns))
        return df.astype('float32'), dist, mask.astype('uint8')

    def load_distance_matrix(self, ids):
        path = os.path.join(datasets_path['bay'], 'pems_bay_dist.npy')
        try:
            dist = np.load(path)
        except FileNotFoundError:
            # no cached matrix yet: build it from the raw CSV and cache it
            distances = pd.read_csv(os.path.join(datasets_path['bay'], 'distances_bay.csv'))
            num_sensors = len(ids)
            dist = np.ones((num_sensors, num_sensors), dtype=np.float32) * np.inf
            # Builds sensor id to index map.
            sensor_id_to_ind = {int(sensor_id): i for i, sensor_id in enumerate(ids)}

            # Fills cells in the matrix with distances.
            for row in distances.values:
                if row[0] not in sensor_id_to_ind or row[1] not in sensor_id_to_ind:
                    continue
                dist[sensor_id_to_ind[row[0]], sensor_id_to_ind[row[1]]] = row[2]
            np.save(path, dist)
        return dist

    def get_similarity(self, type='dcrnn', thr=0.1, force_symmetric=False, sparse=False):
        """
        Return similarity matrix among nodes. Implemented to match DCRNN.

        :param type: type of similarity matrix.
        :param thr: threshold to increase saprseness.
        :param trainlen: number of steps that can be used for computing the similarity.
        :param force_symmetric: force the result to be simmetric.
        :return: and NxN array representig similarity among nodes.
        """
        if type == 'dcrnn':
            finite_dist = self.dist.reshape(-1)
            finite_dist = finite_dist[~np.isinf(finite_dist)]
            sigma = finite_dist.std()
            adj = np.exp(-np.square(self.dist / sigma))
        elif type == 'stcn':
            sigma = 10
            adj = np.exp(-np.square(self.dist) / sigma)
        else:
            raise NotImplementedError
        adj[adj < thr] = 0.
        if force_symmetric:
            adj = np.maximum.reduce([adj, adj.T])
        if sparse:
            import scipy.sparse as sps
            adj = sps.coo_matrix(adj)
        return adj

    @property
    def mask(self):
        if self._mask is None:
            return self.df.values != 0.
        return self._mask


class MissingValuesPemsBay(PemsBay):
    SEED = 56789

    def __init__(self, p_fault=0.0015, p_noise=0.05):
        super(MissingValuesPemsBay, self).__init__()
        self.rng = np.random.default_rng(self.SEED)
        self.p_fault = p_fault
        self.p_noise = p_noise
        eval_mask = sample_mask(self.numpy().shape,
                                p=p_fault,
                                p_noise=p_noise,
                                min_seq=12,
                                max_seq=12 * 4,
                                rng=self.rng)
        self.eval_mask = (eval_mask & self.mask).astype('uint8')

    @property
    def training_mask(self):
        return self.mask if self.eval_mask is None else (self.mask & (1 - self.eval_mask))

    def splitter(self, dataset, val_len=0, test_len=0, window=0):
        idx = np.arange(len(dataset))
        if test_len < 1:
            test_len = int(test_len * len(idx))
        if val_len < 1:
            val_len = int(val_len * (len(idx) - test_len))
        test_start = len(idx) - test_len
        val_start = test_start - val_len
        return [idx[:val_start - window], idx[val_start:test_start - window], idx[test_start:]]

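# Illustrative sketch (hypothetical helper, not part of the original library and
# not used by it): the `splitter` methods of MissingValuesPemsBay and
# MissingValuesMetrLA split sample indices chronologically. Lengths below 1 are
# read as fractions (val_len as a fraction of the non-test samples), and the last
# `window` indices of the training and validation sets are dropped, presumably so
# that windows starting there do not overlap the following split.
import numpy as np


def _sketch_chronological_split(n_samples, val_len=0., test_len=0., window=0):
    idx = np.arange(n_samples)
    if test_len < 1:
        test_len = int(test_len * n_samples)
    if val_len < 1:
        val_len = int(val_len * (n_samples - test_len))
    test_start = n_samples - test_len
    val_start = test_start - val_len
    return idx[:val_start - window], idx[val_start:test_start - window], idx[test_start:]
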

================================================
FILE: lib/datasets/synthetic.py
================================================
import os.path

import numpy as np
import torch
from einops import rearrange
from torch.utils.data import Dataset, DataLoader, Subset

from lib import datasets_path


def generate_mask(shape, p_block=0.01, p_point=0.01, max_seq=1, min_seq=1, rng=None):
    """Generate mask in which 1 denotes valid values, 0 missing ones. Assuming shape=(steps, ...)."""
    if rng is None:
        rand = np.random.random
        randint = np.random.randint
    else:
        rand = rng.random
        randint = rng.integers
    # init mask
    mask = np.ones(shape, dtype='uint8')
    # block missing
    if p_block > 0:
        assert max_seq >= min_seq
        for col in range(shape[1]):
            i = 0
            while i < shape[0]:
                if rand() > p_block:
                    i += 1
                else:
                    fault_len = int(randint(min_seq, max_seq + 1))
                    mask[i:i + fault_len, col] = 0
                    i += fault_len + 1  # at least one valid value between two blocks
    # point missing
    # keep the values immediately before and after a missing block always valid;
    # cast to a signed type first, since np.diff on uint8 wraps around (0 - 1 == 255)
    step = np.diff(mask.astype('int8'), axis=0)
    diff = np.zeros(mask.shape, dtype='uint8')
    diff[:-1] |= step < 0
    diff[1:] |= step > 0
    mask = np.where(mask - diff, rand(shape) > p_point, mask)
    return mask


class SyntheticDataset(Dataset):
    SEED: int

    def __init__(self, filename,
                 window=None,
                 p_block=0.05,
                 p_point=0.05,
                 max_seq=6,
                 min_seq=4,
                 use_exogenous=True,
                 mask_exogenous=True,
                 graph_mode=True):
        super(SyntheticDataset, self).__init__()
        self.mask_exogenous = mask_exogenous
        self.use_exogenous = use_exogenous
        self.graph_mode = graph_mode
        # fetch data
        content = self.load(filename)
        self.window = window if window is not None else content['loc'].shape[1]
        self.loc = torch.tensor(content['loc'][:, :self.window]).float()
        self.vel = torch.tensor(content['vel'][:, :self.window]).float()
        self.adj = content['adj']
        self.SEED = content['seed'].item()
        # compute masks
        self.rng = np.random.default_rng(self.SEED)
        mask_shape = (len(self), self.window, self.n_nodes, 1)
        mask = generate_mask(mask_shape,
                             p_block=p_block,
                             p_point=p_point,
                             max_seq=max_seq,
                             min_seq=min_seq,
                             rng=self.rng).repeat(self.n_channels, -1)
        eval_mask = 1 - generate_mask(mask_shape,
                                      p_block=p_block,
                                      p_point=p_point,
                                      max_seq=max_seq,
                                      min_seq=min_seq,
                                      rng=self.rng).repeat(self.n_channels, -1)
        self.mask = torch.tensor(mask).byte()
        self.eval_mask = torch.tensor(eval_mask).byte() & self.mask
        # store splitting indices
        self.train_idxs = None
        self.val_idxs = None
        self.test_idxs = None

    def __len__(self):
        return self.loc.size(0)

    def __getitem__(self, index):
        eval_mask = self.eval_mask[index]
        mask = self.training_mask[index]
        x = mask * self.loc[index]
        res = dict(x=x, mask=mask, eval_mask=eval_mask)
        if self.use_exogenous:
            u = self.vel[index]
            if self.mask_exogenous:
                # out-of-place product: `u` is a view of self.vel, so `*=` would modify the dataset
                u = u * mask.all(-1, keepdim=True)
            res.update(u=u)
        res.update(y=self.loc[index])
        if not self.graph_mode:
            res = {k: rearrange(v, 's n f -> s (n f)') for k, v in res.items()}
        return res

    @property
    def n_channels(self):
        return self.loc.size(-1)

    @property
    def n_nodes(self):
        return self.loc.size(-2)

    @property
    def n_exogenous(self):
        return self.vel.size(-1) if self.use_exogenous else 0

    @property
    def training_mask(self):
        return self.mask if self.eval_mask is None else (self.mask & (1 - self.eval_mask))

    @staticmethod
    def load(filename):
        return np.load(filename)

    def get_similarity(self, sparse=False):
        return self.adj

    # Splitting options

    def split(self, val_len=0, test_len=0):
        idx = np.arange(len(self))
        if test_len < 1:
            test_len = int(test_len * len(idx))
        if val_len < 1:
            val_len = int(val_len * (len(idx) - test_len))
        test_start = len(idx) - test_len
        val_start = test_start - val_len
        # split dataset
        self.train_idxs = idx[:val_start]
        self.val_idxs = idx[val_start:test_start]
        self.test_idxs = idx[test_start:]

    def train_dataloader(self, shuffle=True, batch_size=32):
        return DataLoader(Subset(self, self.train_idxs), shuffle=shuffle, batch_size=batch_size, drop_last=True)

    def val_dataloader(self, shuffle=False, batch_size=32):
        return DataLoader(Subset(self, self.val_idxs), shuffle=shuffle, batch_size=batch_size)

    def test_dataloader(self, shuffle=False, batch_size=32):
        return DataLoader(Subset(self, self.test_idxs), shuffle=shuffle, batch_size=batch_size)


class ChargedParticles(SyntheticDataset):

    def __init__(self, static_adj=False,
                 window=None,
                 p_block=0.05,
                 p_point=0.05,
                 max_seq=6,
                 min_seq=4,
                 use_exogenous=True,
                 mask_exogenous=True,
                 graph_mode=True):
        if static_adj:
            filename = os.path.join(datasets_path['synthetic'], 'charged_static.npz')
        else:
            filename = os.path.join(datasets_path['synthetic'], 'charged_varying.npz')
        self.static_adj = static_adj
        super(ChargedParticles, self).__init__(filename, window,
                                               p_block=p_block,
                                               p_point=p_point,
                                               max_seq=max_seq,
                                               min_seq=min_seq,
                                               use_exogenous=use_exogenous,
                                               mask_exogenous=mask_exogenous,
                                               graph_mode=graph_mode)
        charges = self.load(filename)['charges']
        self.charges = torch.tensor(charges).float()

    def __getitem__(self, item):
        res = super(ChargedParticles, self).__getitem__(item)
        # add charges as exogenous features
        if self.use_exogenous:
            charges = self.charges[item] if not self.static_adj else self.charges
            stacked_charges = charges[None].expand(self.window, -1, -1)
            if not self.graph_mode:
                stacked_charges = rearrange(stacked_charges, 's n f -> s (n f)')
            res.update(u=torch.cat([res['u'], stacked_charges], -1))
        return res

    def get_similarity(self, sparse=False):
        return np.ones((self.n_nodes, self.n_nodes)) - np.eye(self.n_nodes)

    @property
    def n_exogenous(self):
        if self.use_exogenous:
            return super(ChargedParticles, self).n_exogenous + 1  # add charges to features
        return 0

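# Illustrative sketch (hypothetical helper, not part of the original library and
# not used by it): `generate_mask` keeps the steps right before and right after
# each missing block valid when sampling point-missing values. Finding those steps
# needs a signed difference along the time axis: on a uint8 array np.diff wraps
# around (0 - 1 == 255), so the mask is cast to a signed type first.
import numpy as np


def _sketch_block_boundaries(mask):
    step = np.diff(mask.astype(np.int8), axis=0)
    boundary = np.zeros(mask.shape, dtype=bool)
    boundary[:-1] |= step < 0  # last valid step before a block starts
    boundary[1:] |= step > 0   # first valid step after a block ends
    return boundary.astype('uint8')
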

================================================
FILE: lib/fillers/__init__.py
================================================
from .filler import Filler
from .britsfiller import BRITSFiller
from .graphfiller import GraphFiller
from .rgainfiller import RGAINFiller
from .multi_imputation_filler import MultiImputationFiller


================================================
FILE: lib/fillers/britsfiller.py
================================================
import torch

from . import Filler
from ..nn import BRITS


class BRITSFiller(Filler):

    def training_step(self, batch, batch_idx):
        # Unpack batch
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        # Extract mask and target
        mask = batch_data['mask'].clone().detach()
        batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()
        eval_mask = batch_data.pop('eval_mask', None)
        y = batch_data.pop('y')

        # Compute predictions and compute loss
        out, imputations, predictions = self.predict_batch(batch, preprocess=False, postprocess=False)

        if self.scaled_target:
            target = self._preprocess(y, batch_preprocessing)
        else:
            target = y
            imputations = [self._postprocess(imp, batch_preprocessing) for imp in imputations]
            predictions = [self._postprocess(prd, batch_preprocessing) for prd in predictions]

        loss = sum([self.loss_fn(pred, target, mask) for pred in predictions])
        loss += BRITS.consistency_loss(*imputations)

        # Logging
        metrics_mask = (mask | eval_mask) - batch_data['mask']  # all unseen data
        out = self._postprocess(out, batch_preprocessing)
        self.train_metrics.update(out.detach(), y, metrics_mask)
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log('train_loss', loss, on_step=False, on_epoch=True, logger=True, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        # Unpack batch
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        # Extract mask and target
        mask = batch_data.get('mask')
        eval_mask = batch_data.pop('eval_mask', None)
        y = batch_data.pop('y')

        # Compute predictions and compute loss
        out, imputations, predictions = self.predict_batch(batch, preprocess=False, postprocess=False)

        if self.scaled_target:
            target = self._preprocess(y, batch_preprocessing)
        else:
            target = y
            predictions = [self._postprocess(prd, batch_preprocessing) for prd in predictions]

        val_loss = sum([self.loss_fn(pred, target, mask) for pred in predictions])

        # Logging
        out = self._postprocess(out, batch_preprocessing)
        self.val_metrics.update(out.detach(), y, eval_mask)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log('val_loss', val_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
        return val_loss

    def test_step(self, batch, batch_idx):
        # Unpack batch
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        # Extract mask and target
        eval_mask = batch_data.pop('eval_mask', None)
        y = batch_data.pop('y')

        # Compute outputs and rescale
        imputation, *_ = self.predict_batch(batch, preprocess=False, postprocess=True)
        test_loss = self.loss_fn(imputation, y, eval_mask)

        # Logging
        self.test_metrics.update(imputation.detach(), y, eval_mask)
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log('test_loss', test_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
        return test_loss

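# Illustrative sketch (hypothetical helper, not part of the original library and
# not used by it): a NumPy transcription of the masking done in `training_step`.
# Each observed entry is kept as input with probability `keep_prob`
# (1 - whiten_prob), and metrics are computed on every entry that has a ground
# truth but was not shown to the model, i.e. (mask | eval_mask) - input_mask.
# The original uses torch.bernoulli on tensors; here a NumPy Generator draws the
# Bernoulli samples instead.
import numpy as np


def _sketch_whiten(mask, eval_mask, keep_prob, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    keep = rng.random(mask.shape) < keep_prob  # Bernoulli(keep_prob) draws
    input_mask = (mask.astype(bool) & keep).astype('uint8')
    # input_mask is a subset of mask, so the uint8 subtraction cannot underflow
    unseen = (mask | eval_mask) - input_mask
    return input_mask, unseen
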

================================================
FILE: lib/fillers/filler.py
================================================
import inspect
from copy import deepcopy

import pytorch_lightning as pl
import torch
from pytorch_lightning.core.decorators import auto_move_data
from pytorch_lightning.metrics import MetricCollection
from pytorch_lightning.utilities import move_data_to_device

from .. import epsilon
from ..nn.utils.metric_base import MaskedMetric
from ..utils.utils import ensure_list


class Filler(pl.LightningModule):
    def __init__(self,
                 model_class,
                 model_kwargs,
                 optim_class,
                 optim_kwargs,
                 loss_fn,
                 scaled_target=False,
                 whiten_prob=0.05,
                 metrics=None,
                 scheduler_class=None,
                 scheduler_kwargs=None):
        """
        PL module to implement hole fillers.

        :param model_class: Class of pytorch nn.Module implementing the imputer.
        :param model_kwargs: Model's keyword arguments.
        :param optim_class: Optimizer class.
        :param optim_kwargs: Optimizer's keyword arguments.
        :param loss_fn: Loss function used for training.
        :param scaled_target: Whether to scale target before computing loss using batch processing information.
        :param whiten_prob: Probability of removing a value and using it as ground truth for imputation.
        :param metrics: Dictionary of type {'metric1_name':metric1_fn, 'metric2_name':metric2_fn ...}.
        :param scheduler_class: Scheduler class.
        :param scheduler_kwargs: Scheduler's keyword arguments.
        """
        super(Filler, self).__init__()
        self.save_hyperparameters(model_kwargs)
        self.model_cls = model_class
        self.model_kwargs = model_kwargs
        self.optim_class = optim_class
        self.optim_kwargs = optim_kwargs
        self.scheduler_class = scheduler_class
        if scheduler_kwargs is None:
            self.scheduler_kwargs = dict()
        else:
            self.scheduler_kwargs = scheduler_kwargs

        if loss_fn is not None:
            self.loss_fn = self._check_metric(loss_fn, on_step=True)
        else:
            self.loss_fn = None

        self.scaled_target = scaled_target

        # during training whiten ground-truth values with this probability
        assert 0. <= whiten_prob <= 1.
        self.keep_prob = 1. - whiten_prob

        if metrics is None:
            metrics = dict()
        self._set_metrics(metrics)
        # instantiate model
        self.model = self.model_cls(**self.model_kwargs)

    def reset_model(self):
        self.model = self.model_cls(**self.model_kwargs)

    @property
    def trainable_parameters(self):
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    @auto_move_data
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    @staticmethod
    def _check_metric(metric, on_step=False):
        if not isinstance(metric, MaskedMetric):
            if 'reduction' in inspect.getfullargspec(metric).args:
                metric_kwargs = {'reduction': 'none'}
            else:
                metric_kwargs = dict()
            return MaskedMetric(metric, compute_on_step=on_step, metric_kwargs=metric_kwargs)
        return deepcopy(metric)

    def _set_metrics(self, metrics):
        self.train_metrics = MetricCollection(
            {f'train_{k}': self._check_metric(m, on_step=True) for k, m in metrics.items()})
        self.val_metrics = MetricCollection({f'val_{k}': self._check_metric(m) for k, m in metrics.items()})
        self.test_metrics = MetricCollection({f'test_{k}': self._check_metric(m) for k, m in metrics.items()})

    def _preprocess(self, data, batch_preprocessing):
        """
        Perform preprocessing of a given input.

        :param data: pytorch tensor of shape [batch, steps, nodes, features] to preprocess
        :param batch_preprocessing: dictionary containing preprocessing data
        :return: preprocessed data
        """
        if isinstance(data, (list, tuple)):
            return [self._preprocess(d, batch_preprocessing) for d in data]
        trend = batch_preprocessing.get('trend', 0.)
        bias = batch_preprocessing.get('bias', 0.)
        scale = batch_preprocessing.get('scale', 1.)
        return (data - trend - bias) / (scale + epsilon)

    def _postprocess(self, data, batch_preprocessing):
        """
        Perform postprocessing (inverse transform) of a given input.

        :param data: pytorch tensor of shape [batch, steps, nodes, features] to transform
        :param batch_preprocessing: dictionary containing preprocessing data
        :return: inverse transformed data
        """
        if isinstance(data, (list, tuple)):
            return [self._postprocess(d, batch_preprocessing) for d in data]
        trend = batch_preprocessing.get('trend', 0.)
        bias = batch_preprocessing.get('bias', 0.)
        scale = batch_preprocessing.get('scale', 1.)
        return data * (scale + epsilon) + bias + trend

    def predict_batch(self, batch, preprocess=False, postprocess=True, return_target=False):
        """
        This method takes as input a batch, given as two dictionaries of tensors, and outputs the predictions.
        Prediction should have a shape [batch, nodes, horizon]

        :param batch: list of two dictionaries following the structure [data:
                                                                {'x':[...], 'y':[...], 'u':[...], ...},
                                                              preprocessing:
                                                                {'bias': ..., 'scale': ..., 'x_trend':[...], 'y_trend':[...]}]
                      (a single data dictionary is also accepted, with no preprocessing)
        :param preprocess: whether the inputs need to be preprocessed (note that inputs are by default preprocessed before creating the batch)
        :param postprocess: whether to postprocess the predictions (if True, we assume that the model has learned to predict the transformed signal)
        :param return_target: whether to return the prediction target y_true and the prediction mask
        :return: (y_true), y_hat, (mask)
        """
        batch_data, batch_preprocessing = self._unpack_batch(batch)
        if preprocess:
            x = batch_data.pop('x')
            x = self._preprocess(x, batch_preprocessing)
            y_hat = self.forward(x, **batch_data)
        else:
            y_hat = self.forward(**batch_data)
        # Rescale outputs
        if postprocess:
            y_hat = self._postprocess(y_hat, batch_preprocessing)
        if return_target:
            y = batch_data.get('y')
            mask = batch_data.get('mask', None)
            return y, y_hat, mask
        return y_hat

    def predict_loader(self, loader, preprocess=False, postprocess=True, return_mask=True):
        """
        Makes predictions for an input dataloader. Returns both the predictions and the prediction targets.

        :param loader: torch dataloader
        :param preprocess: whether to preprocess the data
        :param postprocess: whether to postprocess the data
        :param return_mask: whether to also return the evaluation mask (None if the batches contain no 'eval_mask')
        :return: y_true, y_hat and, if return_mask is True, the evaluation mask
        """
        targets, imputations, masks = [], [], []
        for batch in loader:
            batch = move_data_to_device(batch, self.device)
            batch_data, batch_preprocessing = self._unpack_batch(batch)
            # Extract mask and target
            eval_mask = batch_data.pop('eval_mask', None)
            y = batch_data.pop('y')

            y_hat = self.predict_batch(batch, preprocess=preprocess, postprocess=postprocess)

            if isinstance(y_hat, (list, tuple)):
                y_hat = y_hat[0]

            targets.append(y)
            imputations.append(y_hat)
            masks.append(eval_mask)

        y = torch.cat(targets, 0)
        y_hat = torch.cat(imputations, 0)
        if return_mask:
            mask = torch.cat(masks, 0) if masks[0] is not None else None
            return y, y_hat, mask
        return y, y_hat

    def _unpack_batch(self, batch):
        """
        Unpack a batch into data and preprocessing dictionaries.

        :param batch: the batch
        :return: batch_data, batch_preprocessing
        """
        if isinstance(batch, (tuple, list)) and (len(batch) == 2):
            batch_data, batch_preprocessing = batch
        else:
            batch_data = batch
            batch_preprocessing = dict()
        return batch_data, batch_preprocessing

    def training_step(self, batch, batch_idx):
        # Unpack batch
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        # Extract mask and target
        mask = batch_data['mask'].clone().detach()
        batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()
        eval_mask = batch_data.pop('eval_mask')
        eval_mask = (mask | eval_mask) - batch_data['mask']

        y = batch_data.pop('y')

        # Compute predictions and compute loss
        imputation = self.predict_batch(batch, preprocess=False, postprocess=False)

        if self.scaled_target:
            target = self._preprocess(y, batch_preprocessing)
        else:
            target = y
            imputation = self._postprocess(imputation, batch_preprocessing)

        loss = self.loss_fn(imputation, target, mask)

        # Logging
        if self.scaled_target:
            imputation = self._postprocess(imputation, batch_preprocessing)
        self.train_metrics.update(imputation.detach(), y, eval_mask)  # all unseen data
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log('train_loss', loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        # Unpack batch
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        # Extract mask and target
        eval_mask = batch_data.pop('eval_mask', None)
        y = batch_data.pop('y')

        # Compute predictions and compute loss
        imputation = self.predict_batch(batch, preprocess=False, postprocess=False)

        if self.scaled_target:
            target = self._preprocess(y, batch_preprocessing)
        else:
            target = y
            imputation = self._postprocess(imputation, batch_preprocessing)

        val_loss = self.loss_fn(imputation, target, eval_mask)

        # Logging
        if self.scaled_target:
            imputation = self._postprocess(imputation, batch_preprocessing)
        self.val_metrics.update(imputation.detach(), y, eval_mask)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log('val_loss', val_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
        return val_loss

    def test_step(self, batch, batch_idx):
        # Unpack batch
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        # Extract mask and target
        eval_mask = batch_data.pop('eval_mask', None)
        y = batch_data.pop('y')

        # Compute outputs and rescale
        imputation = self.predict_batch(batch, preprocess=False, postprocess=True)
        test_loss = self.loss_fn(imputation, y, eval_mask)

        # Logging
        self.test_metrics.update(imputation.detach(), y, eval_mask)
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        return test_loss

    def on_train_epoch_start(self) -> None:
        optimizers = ensure_list(self.optimizers())
        for i, optimizer in enumerate(optimizers):
            lr = optimizer.optimizer.param_groups[0]['lr']
            self.log(f'lr_{i}', lr, on_step=False, on_epoch=True, logger=True, prog_bar=False)

    def configure_optimizers(self):
        cfg = dict()
        optimizer = self.optim_class(self.parameters(), **self.optim_kwargs)
        cfg['optimizer'] = optimizer
        if self.scheduler_class is not None:
            metric = self.scheduler_kwargs.pop('monitor', None)
            scheduler = self.scheduler_class(optimizer, **self.scheduler_kwargs)
            cfg['lr_scheduler'] = scheduler
            if metric is not None:
                cfg['monitor'] = metric
        return cfg
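
The masking in `Filler.training_step` can be illustrated in isolation. The sketch below is a standalone approximation, not part of the library: `whiten_masks` is a hypothetical helper, and `keep_prob` is assumed to be the probability of keeping an observed value as model input. Observed entries are randomly hidden from the model, and the training metrics are computed on every entry that has a ground truth (observed or held out) but was not fed to the model.

```python
import torch


def whiten_masks(mask: torch.Tensor, eval_mask: torch.Tensor, keep_prob: float):
    """Hypothetical helper mirroring the masking in Filler.training_step.

    mask:      uint8 tensor, 1 where a value is observed
    eval_mask: uint8 tensor, 1 where a value is held out for evaluation
    Returns the mask fed to the model and the mask of all unseen values.
    """
    # keep each observed entry independently with probability keep_prob
    input_mask = torch.bernoulli(mask.float() * keep_prob).byte()
    # entries with a ground truth (observed or held out) that the model did not see
    unseen_mask = (mask | eval_mask) - input_mask
    return input_mask, unseen_mask


torch.manual_seed(0)
mask = torch.tensor([[1, 1, 0, 1]], dtype=torch.uint8)
eval_mask = torch.tensor([[0, 0, 1, 0]], dtype=torch.uint8)
input_mask, unseen_mask = whiten_masks(mask, eval_mask, keep_prob=0.5)
```

Because `input_mask` is always a subset of `mask`, the uint8 subtraction never goes negative, and with `keep_prob=1.0` the unseen mask reduces to the held-out entries only.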


================================================
FILE: lib/fillers/graphfiller.py
================================================
import torch

from . import Filler
from ..nn.models import MPGRUNet, GRINet, BiMPGRUNet


class GraphFiller(Filler):

    def __init__(self,
                 model_class,
                 model_kwargs,
                 optim_class,
                 optim_kwargs,
                 loss_fn,
                 scaled_target=False,
                 whiten_prob=0.05,
                 pred_loss_weight=1.,
                 warm_up=0,
                 metrics=None,
                 scheduler_class=None,
                 scheduler_kwargs=None):
        super(GraphFiller, self).__init__(model_class=model_class,
                                          model_kwargs=model_kwargs,
                                          optim_class=optim_class,
                                          optim_kwargs=optim_kwargs,
                                          loss_fn=loss_fn,
                                          scaled_target=scaled_target,
                                          whiten_prob=whiten_prob,
                                          metrics=metrics,
                                          scheduler_class=scheduler_class,
                                          scheduler_kwargs=scheduler_kwargs)

        self.tradeoff = pred_loss_weight
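        if model_class is MPGRUNet:
            self.trimming = (warm_up, 0)
        elif model_class in [GRINet, BiMPGRUNet]:
            self.trimming = (warm_up, warm_up)
        else:
            # no warm-up trimming for other models (avoids an undefined attribute in trim_seq)
            self.trimming = (0, 0)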

    def trim_seq(self, *seq):
        seq = [s[:, self.trimming[0]:s.size(1) - self.trimming[1]] for s in seq]
        if len(seq) == 1:
            return seq[0]
        return seq

    def training_step(self, batch, batch_idx):
        # Unpack batch
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        # Compute masks
        mask = batch_data['mask'].clone().detach()
        batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()
        eval_mask = batch_data.pop('eval_mask', None)
        eval_mask = (mask | eval_mask) - batch_data['mask']  # all unseen data

        y = batch_data.pop('y')

        # Compute predictions and compute loss
        res = self.predict_batch(batch, preprocess=False, postprocess=False)
        imputation, predictions = (res[0], res[1:]) if isinstance(res, (list, tuple)) else (res, [])

        # trim to imputation horizon len
        imputation, mask, eval_mask, y = self.trim_seq(imputation, mask, eval_mask, y)
        predictions = self.trim_seq(*predictions)

        if self.scaled_target:
            target = self._preprocess(y, batch_preprocessing)
        else:
            target = y
            imputation = self._postprocess(imputation, batch_preprocessing)
            for i, _ in enumerate(predictions):
                predictions[i] = self._postprocess(predictions[i], batch_preprocessing)

        loss = self.loss_fn(imputation, target, mask)
        for pred in predictions:
            loss += self.tradeoff * self.loss_fn(pred, target, mask)

        # Logging
        if self.scaled_target:
            imputation = self._postprocess(imputation, batch_preprocessing)
        self.train_metrics.update(imputation.detach(), y, eval_mask)  # all unseen data
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log('train_loss', loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        # Unpack batch
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        # Extract mask and target
        mask = batch_data.get('mask')
        eval_mask = batch_data.pop('eval_mask', None)
        y = batch_data.pop('y')

        # Compute predictions and compute loss
        imputation = self.predict_batch(batch, preprocess=False, postprocess=False)

        # trim to imputation horizon len
        imputation, mask, eval_mask, y = self.trim_seq(imputation, mask, eval_mask, y)

        if self.scaled_target:
            target = self._preprocess(y, batch_preprocessing)
        else:
            target = y
            imputation = self._postprocess(imputation, batch_preprocessing)

        val_loss = self.loss_fn(imputation, target, eval_mask)

        # Logging
        if self.scaled_target:
            imputation = self._postprocess(imputation, batch_preprocessing)
        self.val_metrics.update(imputation.detach(), y, eval_mask)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log('val_loss', val_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
        return val_loss

    def test_step(self, batch, batch_idx):
        # Unpack batch
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        # Extract mask and target
        eval_mask = batch_data.pop('eval_mask', None)
        y = batch_data.pop('y')

        # Compute outputs and rescale
        imputation = self.predict_batch(batch, preprocess=False, postprocess=True)
        test_loss = self.loss_fn(imputation, y, eval_mask)

        # Logging
        self.test_metrics.update(imputation.detach(), y, eval_mask)
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log('test_loss', test_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
        return test_loss
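
`trim_seq` discards the warm-up steps along the time axis (dimension 1, assuming `[batch, steps, ...]` tensors) before the loss is computed: the unidirectional `MPGRUNet` only drops its first `warm_up` steps, while the bidirectional models also drop the last `warm_up` steps. The sketch below reproduces the trimming and the `pred_loss_weight` trade-off in isolation; `masked_mae` and `total_loss` are hypothetical stand-ins for `loss_fn` and the loss accumulation in `training_step`.

```python
import torch


def trim_seq(trimming, *seq):
    # drop trimming[0] leading and trimming[1] trailing steps along dim 1 (time)
    start, end = trimming
    out = [s[:, start:s.size(1) - end] for s in seq]
    return out[0] if len(out) == 1 else out


def masked_mae(y_hat, y, mask):
    # hypothetical stand-in for loss_fn: MAE over entries where mask == 1
    mask = mask.float()
    return (torch.abs(y_hat - y) * mask).sum() / mask.sum().clamp(min=1)


def total_loss(imputation, predictions, target, mask, tradeoff):
    # loss on the final imputation plus weighted losses on the auxiliary outputs
    loss = masked_mae(imputation, target, mask)
    for pred in predictions:
        loss = loss + tradeoff * masked_mae(pred, target, mask)
    return loss


x = torch.zeros(2, 10, 3, 1)  # [batch, steps, nodes, channels]
print(trim_seq((2, 2), x).shape)  # torch.Size([2, 6, 3, 1])
print(trim_seq((2, 0), x).shape)  # torch.Size([2, 8, 3, 1])
```

Slicing up to `s.size(1) - end` rather than `-end` matters when `end == 0`: `s[:, start:-0]` would be empty.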


================================================
FILE: lib/fillers/multi_imputation_filler.py
================================================
import torch
from pytorch_lightning.core.decorators import auto_move_data

from . import Filler


class MultiImputationFiller(Filler):
    """
    Filler with multiple imputation outputs
    """

    def __init__(self,
                 model_class,
                 model_kwargs,
                 optim_class,
                 optim_kwargs,
                 loss_fn,
                 consistency_loss=False,
                 scaled_target=False,
                 whiten_prob=0.05,
                 metrics=None,
                 scheduler_class=None,
                 scheduler_kwargs=None):

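        super().__init__(model_class=model_class,
                         model_kwargs=model_kwargs,
                         optim_class=optim_class,
                         optim_kwargs=optim_kwargs,
                         loss_fn=loss_fn,
                         scaled_target=scaled_target,
                         whiten_prob=whiten_prob,
                         metrics=metrics,
                         scheduler_class=scheduler_class,
                         scheduler_kwargs=scheduler_kwargs)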
        self.consistency_loss = consistency_loss

    @auto_move_data
    def forward(self, *args, **kwargs):
        out = self.model(*args, **kwargs)
        assert isinstance(out, (list, tuple))
        if self.training:
            return out
        return out[0]  # we assume that the final imputation is the first one

    def _consistency_loss(self, imputations, mask):
        from itertools import combinations
        return sum([self.loss_fn(imp1, imp2, mask) for imp1, imp2 in combinations(imputations, 2)])

    def training_step(self, batch, batch_idx):
        # Unpack batch
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        # Extract mask and target
        mask = batch_data['mask'].clone().detach()
        batch_data['mask'] = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()
        eval_mask = batch_data.pop('eval_mask', None)
        y = batch_data.pop('y')

        # Compute predictions and compute loss
        imputations = self.predict_batch(batch, preprocess=False, postprocess=False)

        if self.scaled_target:
            target = self._preprocess(y, batch_preprocessing)
        else:
            target = y
            imputations = [self._postprocess(imp, batch_preprocessing) for imp in imputations]

        loss = sum([self.loss_fn(imp, target, mask) for imp in imputations])
        if self.consistency_loss:
            loss += self._consistency_loss(imputations, mask)

        # Logging
        metrics_mask = (mask | eval_mask) - batch_data['mask']  # all unseen data

        x_hat = imputations[0]
        x_hat = self._postprocess(x_hat, batch_preprocessing)
        self.train_metrics.update(x_hat.detach(), y, metrics_mask)
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log('train_loss', loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
        return loss
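
`_consistency_loss` sums the loss over every unordered pair of imputations, so a model with k outputs contributes k(k-1)/2 terms. A standalone sketch, with a hypothetical `masked_mse` standing in for `loss_fn`:

```python
from itertools import combinations

import torch


def masked_mse(a, b, mask):
    # hypothetical loss_fn: mean squared error over entries where mask == 1
    mask = mask.float()
    return ((a - b) ** 2 * mask).sum() / mask.sum().clamp(min=1)


def consistency_loss(imputations, mask, loss_fn=masked_mse):
    # penalize the disagreement of every unordered pair of imputations
    return sum(loss_fn(a, b, mask) for a, b in combinations(imputations, 2))


mask = torch.ones(2, 3)
imputations = [torch.zeros(2, 3), torch.ones(2, 3), 2 * torch.ones(2, 3)]
print(consistency_loss(imputations, mask))  # tensor(6.)
```

The three pairs contribute 1, 4 and 1. With a single imputation there are no pairs and the sum is the integer 0.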


================================================
FILE: lib/fillers/rgainfiller.py
================================================
import torch
from torch.nn import functional as F

from .multi_imputation_filler import MultiImputationFiller
from ..nn.utils.metric_base import MaskedMetric


class MaskedBCEWithLogits(MaskedMetric):
    def __init__(self,
                 mask_nans=False,
                 mask_inf=False,
                 compute_on_step=True,
                 dist_sync_on_step=False,
                 process_group=None,
                 dist_sync_fn=None,
                 at=None):
        super(MaskedBCEWithLogits, self).__init__(metric_fn=F.binary_cross_entropy_with_logits,
                                                  mask_nans=mask_nans,
                                                  mask_inf=mask_inf,
                                                  compute_on_step=compute_on_step,
                                                  dist_sync_on_step=dist_sync_on_step,
                                                  process_group=process_group,
                                                  dist_sync_fn=dist_sync_fn,
                                                  metric_kwargs={'reduction': 'none'},
                                                  at=at)


class RGAINFiller(MultiImputationFiller):
    def __init__(self,
                 model_class,
                 model_kwargs,
                 optim_class,
                 optim_kwargs,
                 loss_fn,
                 g_train_freq=1,
                 d_train_freq=5,
                 consistency_loss=False,
                 scaled_target=True,
                 whiten_prob=0.05,
                 hint_rate=0.7,
                 alpha=10.,
                 metrics=None,
                 scheduler_class=None,
                 scheduler_kwargs=None):
        super(RGAINFiller, self).__init__(model_class=model_class,
                                          model_kwargs=model_kwargs,
                                          optim_class=optim_class,
                                          optim_kwargs=optim_kwargs,
                                          loss_fn=loss_fn,
                                          scaled_target=scaled_target,
                                          whiten_prob=whiten_prob,
                                          metrics=metrics,
                                          consistency_loss=consistency_loss,
                                          scheduler_class=scheduler_class,
                                          scheduler_kwargs=scheduler_kwargs)
        # discriminator training params
        self.alpha = alpha
        self.g_train_freq = g_train_freq
        self.d_train_freq = d_train_freq
        self.masked_bce_loss = MaskedBCEWithLogits(compute_on_step=True)
        # activate manual optimization
        self.automatic_optimization = False
        self.hint_rate = hint_rate

    def training_step(self, batch, batch_idx):
        # Unpack batch
        batch_data, batch_preprocessing = self._unpack_batch(batch)
        g_opt, d_opt = self.optimizers()
        schedulers = self.lr_schedulers()

        # Extract mask and target
        x = batch_data.pop('x')
        mask = batch_data['mask'].clone().detach()
        training_mask = torch.bernoulli(mask.clone().detach().float() * self.keep_prob).byte()
        eval_mask = batch_data.pop('eval_mask', None)
        y = batch_data.pop('y')

        ##########################
        #  generate imputations
        ##########################

        imputations = self.model.generator(x, training_mask)
        imputed_seq = imputations[0]
        target = self._preprocess(y, batch_preprocessing)
        y_hat = self._postprocess(imputed_seq, batch_preprocessing)

        x_in = training_mask * x + (1 - training_mask) * imputed_seq
        hint = torch.rand_like(training_mask, dtype=torch.float) < self.hint_rate
        hint = hint.byte()
        hint = hint * training_mask + (1 - hint) * 0.5

        #########################
        #  train generator
        #########################
        if (batch_idx % self.g_train_freq) == 0:

            g_opt.zero_grad()

            rec_loss = sum([torch.sqrt(self.loss_fn(imp, target, mask)) for imp in imputations])
            if self.consistency_loss:
                rec_loss += self._consistency_loss(imputations, mask)

            logits = self.model.discriminator(x_in, hint)
            # adversarial loss: push the discriminator to classify imputed entries (training_mask == 0) as observed
            adv_loss = self.masked_bce_loss(logits, torch.ones_like(logits), 1 - training_mask)

            g_loss = self.alpha * rec_loss + adv_loss

            self.manual_backward(g_loss)
            g_opt.step()

            # Logging
            metrics_mask = (mask | eval_mask) - training_mask
            self.train_metrics.update(y_hat.detach(), y, metrics_mask)  # all unseen data
            self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
            self.log('gen_loss', adv_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
            self.log('imp_loss', rec_loss.detach(), on_step=True, on_epoch=True, logger=True, prog_bar=True)

        ###########################
        # train discriminator
        ###########################

        if (batch_idx % self.d_train_freq) == 0:
            d_opt.zero_grad()

            logits = self.model.discriminator(x_in.detach(), hint)
            d_loss = self.masked_bce_loss(logits, training_mask.to(logits.dtype))

            self.manual_backward(d_loss)
            d_opt.step()
            self.log('d_loss', d_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)

        if (schedulers is not None) and self.trainer.is_last_batch:
            for sch in schedulers:
                sch.step()

    def configure_optimizers(self):
        opt_g = self.optim_class(self.model.generator.parameters(), **self.optim_kwargs)
        opt_d = self.optim_class(self.model.discriminator.parameters(), **self.optim_kwargs)
        optimizers = [opt_g, opt_d]
        if self.scheduler_class is not None:
            metric = self.scheduler_kwargs.pop('monitor', None)
            schedulers = [{"scheduler": self.scheduler_class(opt, **self.scheduler_kwargs), "monitor": metric}
                          for opt in optimizers]
            return optimizers, schedulers
        return optimizers
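
The hint mechanism and the two adversarial objectives in `RGAINFiller.training_step` follow the GAIN scheme; the sketch below restates them with plain PyTorch ops. The helper names are hypothetical, and the masked averaging in `generator_adv_loss` (and the plain mean in `discriminator_loss`) are assumptions about how `MaskedBCEWithLogits` reduces, since `MaskedMetric` is not shown in this section.

```python
import torch
import torch.nn.functional as F


def make_hint(training_mask: torch.Tensor, hint_rate: float) -> torch.Tensor:
    # b == 1 with probability hint_rate: the discriminator is told the true mask value there
    b = (torch.rand_like(training_mask, dtype=torch.float) < hint_rate).float()
    # revealed entries carry the mask value (0 or 1); the others are set to 0.5 ("unknown")
    return b * training_mask.float() + (1 - b) * 0.5


def fill(x, imputed, training_mask):
    # keep the values given to the model, use the generator output elsewhere (x_in)
    m = training_mask.float()
    return m * x + (1 - m) * imputed


def generator_adv_loss(logits, training_mask):
    # the generator wants imputed entries (mask == 0) to be classified as observed (target 1)
    weight = 1 - training_mask.float()
    bce = F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits), reduction='none')
    return (bce * weight).sum() / weight.sum().clamp(min=1)


def discriminator_loss(logits, training_mask):
    # the discriminator predicts, for every entry, whether it was observed
    return F.binary_cross_entropy_with_logits(logits, training_mask.float())
```

With `hint_rate=1.0` the hint equals the mask and the discriminator's task becomes trivial; with `hint_rate=0.0` every entry is 0.5 and the hint carries no information.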


================================================
FILE: lib/nn/__init__.py
================================================
from .layers import *


================================================
FILE: lib/nn/layers/__init__.py
================================================
from .rits import RITS, BRITS
from .gril import GRIL, BiGRIL
from .spatial_conv import SpatialConvOrderK
from .mpgru import MPGRUImputer


================================================
FILE: lib/nn/layers/gcrnn.py
================================================
import torch
import torch.nn as nn

from .spatial_conv import SpatialConvOrderK


class GCGRUCell(nn.Module):
    """
    Graph Convolution Gated Recurrent Unit Cell.
    """

    def __init__(self, d_in, num_units, support_len, order, activation='tanh'):
        """
        :param num_units: the hidden dim of rnn
        :param support_len: the (weighted) adjacency matrix of the graph, in numpy ndarray form
        :param order: the max diffusion step
        :param activation: if None, don't do activation for cell state
        """
        super(GCGRUCell, self).__init__()
        self.activation_fn = getattr(torch, activation)

        self.forget_gate = SpatialConvOrderK(c_in=d_in + num_units, c_out=num_units, support_len=support_len,
                                             order=order)
        self.update_gate = SpatialConvOrderK(c_in=d_in + num_units, c_out=num_units, support_len=support_len,
                                             order=order)
        self.c_gate = SpatialConvOrderK(c_in=d_in + num_units, c_out=num_units, support_len=support_len, order=order)

    def forward(self, x, h, adj):
        """
        :param x: (B, input_dim, num_nodes)
        :param h: (B, num_units, num_nodes)
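        :param adj: the supports (each of shape (num_nodes, num_nodes)) used by the graph convolutions
        :return: the updated hidden state, (B, num_units, num_nodes)
        """
        # r: reset gate, u: update gate; both are graph convolutions over [x, h]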
        x_gates = torch.cat([x, h], dim=1)
        r = torch.sigmoid(self.forget_gate(x_gates, adj))
        u = torch.sigmoid(self.update_gate(x_gates, adj))
        x_c = torch.cat([x, r * h], dim=1)
        c = self.c_gate(x_c, adj)  # (B, num_units, num_nodes)
        c = self.activation_fn(c)
        return u * h + (1. - u) * c


class GCRNN(nn.Module):
    def __init__(self,
                 d_in,
                 d_model,
                 d_out,
                 n_layers,
                 support_len,
                 kernel_size=2):
        super(GCRNN, self).__init__()
        self.d_in = d_in
        self.d_model = d_model
        self.d_out = d_out
        self.n_layers = n_layers
        self.ks = kernel_size
        self.support_len = support_len
        self.rnn_cells = nn.ModuleList()
        for i in range(self.n_layers):
            self.rnn_cells.append(GCGRUCell(d_in=self.d_in if i == 0 else self.d_model,
                                            num_units=self.d_model, support_len=self.support_len, order=self.ks))
        self.output_layer = nn.Conv2d(self.d_model, self.d_out, kernel_size=1)

    def init_hidden_states(self, x):
        return [torch.zeros(size=(x.shape[0], self.d_model, x.shape[2])).to(x.device) for _ in range(self.n_layers)]

    def single_pass(self, x, h, adj):
        out = x
        for l, layer in enumerate(self.rnn_cells):
            out = h[l] = layer(out, h[l], adj)
        return out, h

    def forward(self, x, adj, h=None):
        # x:[batch, features, nodes, steps]
        *_, steps = x.size()
        if h is None:
            h = self.init_hidden_states(x)
        # unroll the recurrence over the time steps
        for step in range(steps):
            out, h = self.single_pass(x[..., step], h, adj)

        return self.output_layer(out[..., None])
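
The update in `GCGRUCell.forward` is a standard GRU whose gates are graph convolutions. The sketch below makes that concrete under a simplifying assumption: `DenseGraphConv` is a hypothetical stand-in for `SpatialConvOrderK` that uses a single dense adjacency matrix and one diffusion step, and `run` unrolls the cell over time as `GCRNN.forward` does.

```python
import torch
import torch.nn as nn


class DenseGraphConv(nn.Module):
    """Hypothetical stand-in for SpatialConvOrderK: concatenates each node's
    features with the adjacency-weighted sum of its neighbors' features and
    applies a 1x1 convolution."""

    def __init__(self, c_in: int, c_out: int):
        super().__init__()
        self.mlp = nn.Conv1d(2 * c_in, c_out, kernel_size=1)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # x: (B, C, N), adj: (N, N); node w receives sum_v adj[w, v] * x[..., v]
        x_nb = torch.einsum('bcv,wv->bcw', x, adj)
        return self.mlp(torch.cat([x, x_nb], dim=1))


class SimpleGCGRUCell(nn.Module):
    """GRU cell whose gates are graph convolutions, following GCGRUCell.forward."""

    def __init__(self, d_in: int, num_units: int):
        super().__init__()
        self.reset_gate = DenseGraphConv(d_in + num_units, num_units)
        self.update_gate = DenseGraphConv(d_in + num_units, num_units)
        self.c_gate = DenseGraphConv(d_in + num_units, num_units)

    def forward(self, x, h, adj):
        x_gates = torch.cat([x, h], dim=1)
        r = torch.sigmoid(self.reset_gate(x_gates, adj))   # reset gate
        u = torch.sigmoid(self.update_gate(x_gates, adj))  # update gate
        c = torch.tanh(self.c_gate(torch.cat([x, r * h], dim=1), adj))  # candidate state
        return u * h + (1. - u) * c  # interpolate between old state and candidate


def run(cell, x, adj, num_units):
    # x: (B, C, N, steps); start from a zero state and feed one step at a time
    h = torch.zeros(x.size(0), num_units, x.size(2))
    for step in range(x.size(-1)):
        h = cell(x[..., step], h, adj)
    return h


torch.manual_seed(0)
adj = torch.rand(5, 5)
adj = adj / adj.sum(1, keepdim=True)  # row-normalized (random-walk) adjacency
cell = SimpleGCGRUCell(d_in=3, num_units=8)
h = run(cell, torch.randn(2, 3, 5, 4), adj, num_units=8)
print(h.shape)  # torch.Size([2, 8, 5])
```

Since each new state is a convex combination of the previous state and a `tanh` candidate, a state initialized at zero stays within [-1, 1].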


================================================
FILE: lib/nn/layers/gril.py
================================================
import torch
import torch.nn as nn
from einops import rearrange

from .spatial_conv import SpatialConvOrderK
from .gcrnn import GCGRUCell
from .spatial_attention import SpatialAttention
from ..utils.ops import reverse_tensor


class SpatialDecoder(nn.Module):
    def __init__(self, d_in, d_model, d_out, support_len, order=1, attention_block=False, nheads=2, dropout=0.):
        super(SpatialDecoder, self).__init__()
        self.order = order
        self.lin_in = nn.Conv1d(d_in, d_model, kernel_size=1)
        self.graph_conv = SpatialConvOrderK(c_in=d_model, c_out=d_model,
                                            support_len=support_len * order, order=1, include_self=False)
        if attention_block:
            self.spatial_att = SpatialAttention(d_in=d_model,
                                                d_model=d_model,
                                                nheads=nheads,
                                                dropout=dropout)
            self.lin_out = nn.Conv1d(3 * d_model, d_model, kernel_size=1)
        else:
            self.register_parameter('spatial_att', None)
            self.lin_out = nn.Conv1d(2 * d_model, d_model, kernel_size=1)
        self.read_out = nn.Conv1d(2 * d_model, d_out, kernel_size=1)
        self.activation = nn.PReLU()
        self.adj = None

    def forward(self, x, m, h, u, adj, cached_support=False):
        # [batch, channels, nodes]
        x_in = [x, m, h] if u is None else [x, m, u, h]
        x_in = torch.cat(x_in, 1)
        if self.order > 1:
            if cached_support and (self.adj is not None):
                adj = self.adj
            else:
                adj = SpatialConvOrderK.compute_support_orderK(adj, self.order, include_self=False, device=x_in.device)
                self.adj = adj if cached_support else None

        x_in = self.lin_in(x_in)
        out = self.graph_conv(x_in, adj)
        if self.spatial_att is not None:
            # [batch, channels, nodes] -> [batch, steps, nodes, features]
            x_in = rearrange(x_in, 'b f n -> b 1 n f')
            out_att = self.spatial_att(x_in, torch.eye(x_in.size(2), dtype=torch.bool, device=x_in.device))
            out_att = rearrange(out_att, 'b s n f -> b f (n s)')
            out = torch.cat([out, out_att], 1)
        out = torch.cat([out, h], 1)
        out = self.activation(self.lin_out(out))
        out = torch.cat([out, h], 1)
        return self.read_out(out), out


class GRIL(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 u_size=None,
                 n_layers=1,
                 dropout=0.,
                 kernel_size=2,
                 decoder_order=1,
                 global_att=False,
                 support_len=2,
                 n_nodes=None,
                 layer_norm=False):
        super(GRIL, self).__init__()
        self.input_size = int(input_size)
        self.hidden_size = int(hidden_size)
        self.u_size = int(u_size) if u_size is not None else 0
        self.n_layers = int(n_layers)
        rnn_input_size = 2 * self.input_size + self.u_size  # input + mask + (eventually) exogenous

        # Spatio-temporal encoder (rnn_input_size -> hidden_size)
        self.cells = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(self.n_layers):
            self.cells.append(GCGRUCell(d_in=rnn_input_size if i == 0 else self.hidden_size,
                                        num_units=self.hidden_size, support_len=support_len, order=kernel_size))
            if layer_norm:
                self.norms.append(nn.GroupNorm(num_groups=1, num_channels=self.hidden_size))
            else:
                self.norms.append(nn.Identity())
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

        # First stage readout
        self.first_stage = nn.Conv1d(in_channels=self.hidden_size, out_channels=self.input_size, kernel_size=1)

        # Spatial decoder (rnn_input_size + hidden_size -> hidden_size)
        self.spatial_decoder = SpatialDecoder(d_in=rnn_input_size + self.hidden_size,
                                              d_model=self.hidden_size,
                                              d_out=self.input_size,
                                              support_len=support_len,
                                              order=decoder_order,
                                              attention_block=global_att)

        # Hidden state initialization embedding
        if n_nodes is not None:
            self.h0 = self.init_hidden_states(n_nodes)
        else:
            self.register_parameter('h0', None)

    def init_hidden_states(self, n_nodes):
        h0 = []
        for l in range(self.n_layers):
            std = 1. / torch.sqrt(torch.tensor(self.hidden_size, dtype=torch.float))
            vals = torch.distributions.Normal(0, std).sample((self.hidden_size, n_nodes))
            h0.append(nn.Parameter(vals))
        return nn.ParameterList(h0)

    def get_h0(self, x):
        if self.h0 is not None:
            return [h.expand(x.shape[0], -1, -1) for h in self.h0]
        return [torch.zeros(size=(x.shape[0], self.hidden_size, x.shape[2])).to(x.device)] * self.n_layers

    def update_state(self, x, h, adj):
        rnn_in = x
        for layer, (cell, norm) in enumerate(zip(self.cells, self.norms)):
            rnn_in = h[layer] = norm(cell(rnn_in, h[layer], adj))
            if self.dropout is not None and layer < (self.n_layers - 1):
                rnn_in = self.dropout(rnn_in)
        return h

    def forward(self, x, adj, mask=None, u=None, h=None, cached_support=False):
        # x:[batch, features, nodes, steps]
        *_, steps = x.size()

        # infer all valid if mask is None
        if mask is None:
            mask = torch.ones_like(x, dtype=torch.uint8)

        # init hidden state using node embedding or the empty state
        if h is None:
            h = self.get_h0(x)
        elif not isinstance(h, list):
            h = [*h]

        # Unroll over the time steps
        predictions, imputations, states = [], [], []
        representations = []
        for step in range(steps):
            x_s = x[..., step]
            m_s = mask[..., step]
            h_s = h[-1]
            u_s = u[..., step] if u is not None else None
            # first stage: impute missing values with predictions from the hidden state
            xs_hat_1 = self.first_stage(h_s)
            # fill missing values in the input with the first-stage prediction
            x_s = torch.where(m_s, x_s, xs_hat_1)
            # second stage: refine imputations with messages from neighbors (no self-loop)
            xs_hat_2, repr_s = self.spatial_decoder(x=x_s, m=m_s, h=h_s, u=u_s, adj=adj,
                                                    cached_support=cached_support)
            # fill missing values with the second-stage imputation and prepare the RNN inputs
            x_s = torch.where(m_s, x_s, xs_hat_2)
            inputs = [x_s, m_s]
            if u_s is not None:
                inputs.append(u_s)
            inputs = torch.cat(inputs, dim=1)  # x_hat_2 + mask + exogenous
            # update state with original sequence filled using imputations
            h = self.update_state(inputs, h, adj)
            # store imputations and states
            imputations.append(xs_hat_2)
            predictions.append(xs_hat_1)
            states.append(torch.stack(h, dim=0))
            representations.append(repr_s)

        # Aggregate outputs -> [batch, features, nodes, steps]
        imputations = torch.stack(imputations, dim=-1)
        predictions = torch.stack(predictions, dim=-1)
        states = torch.stack(states, dim=-1)
        representations = torch.stack(representations, dim=-1)

        return imputations, predictions, representations, states


class BiGRIL(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 ff_size,
                 ff_dropout,
                 n_layers=1,
                 dropout=0.,
                 n_nodes=None,
                 support_len=2,
                 kernel_size=2,
                 decoder_order=1,
                 global_att=False,
                 u_size=0,
                 embedding_size=0,
                 layer_norm=False,
                 merge='mlp'):
        super(BiGRIL, self).__init__()
        self.fwd_rnn = GRIL(input_size=input_size,
                            hidden_size=hidden_size,
                            n_layers=n_layers,
                            dropout=dropout,
                            n_nodes=n_nodes,
                            support_len=support_len,
                            kernel_size=kernel_size,
                            decoder_order=decoder_order,
                            global_att=global_att,
                            u_size=u_size,
                            layer_norm=layer_norm)
        self.bwd_rnn = GRIL(input_size=input_size,
                            hidden_size=hidden_size,
                            n_layers=n_layers,
                            dropout=dropout,
                            n_nodes=n_nodes,
                            support_len=support_len,
                            kernel_size=kernel_size,
                            decoder_order=decoder_order,
                            global_att=global_att,
                            u_size=u_size,
                            layer_norm=layer_norm)

        if n_nodes is None:
            embedding_size = 0
        if embedding_size > 0:
            self.emb = nn.Parameter(torch.empty(embedding_size, n_nodes))
            nn.init.kaiming_normal_(self.emb, nonlinearity='relu')
        else:
            self.register_parameter('emb', None)

        if merge == 'mlp':
            self._impute_from_states = True
            self.out = nn.Sequential(
                nn.Conv2d(in_channels=4 * hidden_size + input_size + embedding_size,
                          out_channels=ff_size, kernel_size=1),
                nn.ReLU(),
                nn.Dropout(ff_dropout),
                nn.Conv2d(in_channels=ff_size, out_channels=input_size, kernel_size=1)
            )
        elif merge in ['mean', 'sum', 'min', 'max']:
            self._impute_from_states = False
            self.out = getattr(torch, merge)
        else:
            raise ValueError("Merge option %s not allowed." % merge)
        self.supp = None

    def forward(self, x, adj, mask=None, u=None, cached_support=False):
        if cached_support and (self.supp is not None):
            supp = self.supp
        else:
            supp = SpatialConvOrderK.compute_support(adj, x.device)
            self.supp = supp if cached_support else None
        # Forward
        fwd_out, fwd_pred, fwd_repr, _ = self.fwd_rnn(x, supp, mask=mask, u=u, cached_support=cached_support)
        # Backward
        rev_x, rev_mask, rev_u = [reverse_tensor(tens) for tens in (x, mask, u)]
        *bwd_res, _ = self.bwd_rnn(rev_x, supp, mask=rev_mask, u=rev_u, cached_support=cached_support)
        bwd_out, bwd_pred, bwd_repr = [reverse_tensor(res) for res in bwd_res]

        if self._impute_from_states:
            inputs = [fwd_repr, bwd_repr, mask]
            if self.emb is not None:
                b, *_, s = fwd_repr.shape  # fwd_repr: [batches, channels, nodes, steps]
                inputs += [self.emb.view(1, *self.emb.shape, 1).expand(b, -1, -1, s)]  # stack emb for batches and steps
            imputation = torch.cat(inputs, dim=1)
            imputation = self.out(imputation)
        else:
            imputation = torch.stack([fwd_out, bwd_out], dim=1)
            imputation = self.out(imputation, dim=1)

        predictions = torch.stack([fwd_out, bwd_out, fwd_pred, bwd_pred], dim=0)

        return imputation, predictions
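

# A minimal, dependency-free sketch of the bidirectional scheme that BiGRIL (and BiMPGRUNet)
# use when `merge` is 'mean'. One imputer runs on the sequence and a second runs on the
# time-reversed sequence. The second output is reversed back and the two are averaged.
# The toy last-observation-carried-forward imputer below is a hypothetical stand-in for GRIL,
# and plain Python lists stand in for tensors. Neither helper is part of the original codebase.
def locf_impute(values, mask, fill=0.0):
    """Causal toy imputer: replace each missing entry with the last observed value."""
    out, last = [], fill
    for v, m in zip(values, mask):
        if m:
            last = v
        out.append(v if m else last)
    return out


def bidirectional_mean_impute(values, mask, impute_only_holes=True):
    fwd = locf_impute(values, mask)
    # Backward pass: reverse the inputs, impute, then reverse the outputs back (cf. reverse_tensor).
    bwd = locf_impute(values[::-1], mask[::-1])[::-1]
    merged = [(f + b) / 2 for f, b in zip(fwd, bwd)]
    if impute_only_holes:
        # Same role as torch.where(mask, x, imputation) in GRINet at evaluation time.
        merged = [v if m else i for v, m, i in zip(values, mask, merged)]
    return merged

# Example: bidirectional_mean_impute([1.0, 0.0, 0.0, 4.0], [1, 0, 0, 1]) -> [1.0, 2.5, 2.5, 4.0]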


================================================
FILE: lib/nn/layers/imputation.py
================================================
import math

import torch
from torch import nn
from torch.nn import functional as F


class ImputationLayer(nn.Module):
    def __init__(self, d_in, bias=True):
        super(ImputationLayer, self).__init__()
        self.W = nn.Parameter(torch.Tensor(d_in, d_in))
        if bias:
            self.b = nn.Parameter(torch.Tensor(d_in))
        else:
            self.register_parameter('b', None)
        mask = 1. - torch.eye(d_in)
        self.register_buffer('mask', mask)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))
        if self.b is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.b, -bound, bound)

    def forward(self, x):
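        # x: [batch, features]; the diagonal of W is masked, so each output feature
        # is computed only from the other input features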
        return F.linear(x, self.mask * self.W, self.b)
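

# A minimal, dependency-free sketch of what ImputationLayer.forward computes for a single sample.
# F.linear(x, mask * W, b) gives y[i] = sum_j (mask * W)[i][j] * x[j] + b[i]. Because `mask`
# zeroes the diagonal, y[i] never depends on x[i], so each feature is estimated only from the
# other features. FeatureRegression in rits.py uses the same masking. The helper name is
# hypothetical, and lists stand in for tensors.
def masked_linear_reference(x, W, b=None):
    d = len(x)
    y = []
    for i in range(d):
        # Skipping j == i reproduces the effect of the (1 - eye) mask buffer.
        acc = sum(W[i][j] * x[j] for j in range(d) if j != i)
        y.append(acc + (b[i] if b is not None else 0.0))
    return y

# Example: masked_linear_reference([10.0, 100.0], [[5.0, 1.0], [2.0, 7.0]]) -> [100.0, 20.0]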

================================================
FILE: lib/nn/layers/mpgru.py
================================================
import torch
from torch import nn

from .gcrnn import GCGRUCell


class MPGRUImputer(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 ff_size=None,
                 u_size=None,
                 n_layers=1,
                 dropout=0.,
                 kernel_size=2,
                 support_len=2,
                 n_nodes=None,
                 layer_norm=False,
                 autoencoder_mode=False):
        super(MPGRUImputer, self).__init__()
        self.input_size = int(input_size)
        self.hidden_size = int(hidden_size)
        self.ff_size = int(ff_size) if ff_size is not None else 0
        self.u_size = int(u_size) if u_size is not None else 0
        self.n_layers = int(n_layers)
        rnn_input_size = 2 * self.input_size + self.u_size  # input + mask + (optional) exogenous

        # Spatio-temporal encoder (rnn_input_size -> hidden_size)
        self.cells = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(self.n_layers):
            self.cells.append(GCGRUCell(d_in=rnn_input_size if i == 0 else self.hidden_size,
                                        num_units=self.hidden_size, support_len=support_len, order=kernel_size))
            if layer_norm:
                self.norms.append(nn.GroupNorm(num_groups=1, num_channels=self.hidden_size))
            else:
                self.norms.append(nn.Identity())
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

        # Readout
        if self.ff_size:
            self.pred_readout = nn.Sequential(
                nn.Conv1d(in_channels=self.hidden_size, out_channels=self.ff_size, kernel_size=1),
                nn.PReLU(),
                nn.Conv1d(in_channels=self.ff_size, out_channels=self.input_size, kernel_size=1)
            )
        else:
            self.pred_readout = nn.Conv1d(in_channels=self.hidden_size, out_channels=self.input_size, kernel_size=1)

        # Hidden state initialization embedding
        if n_nodes is not None:
            self.h0 = self.init_hidden_states(n_nodes)
        else:
            self.register_parameter('h0', None)

        self.autoencoder_mode = autoencoder_mode

    def init_hidden_states(self, n_nodes):
        h0 = []
        for l in range(self.n_layers):
            std = 1. / torch.sqrt(torch.tensor(self.hidden_size, dtype=torch.float))
            vals = torch.distributions.Normal(0, std).sample((self.hidden_size, n_nodes))
            h0.append(nn.Parameter(vals))
        return nn.ParameterList(h0)

    def get_h0(self, x):
        if self.h0 is not None:
            return [h.expand(x.shape[0], -1, -1) for h in self.h0]
        return [torch.zeros(size=(x.shape[0], self.hidden_size, x.shape[2])).to(x.device)] * self.n_layers

    def update_state(self, x, h, adj):
        rnn_in = x
        for layer, (cell, norm) in enumerate(zip(self.cells, self.norms)):
            rnn_in = h[layer] = norm(cell(rnn_in, h[layer], adj))
            if self.dropout is not None and layer < (self.n_layers - 1):
                rnn_in = self.dropout(rnn_in)
        return h

    def forward(self, x, adj, mask=None, u=None, h=None):
        # x:[batch, features, nodes, steps]
        *_, steps = x.size()

        # treat all values as valid if mask is None (torch.where needs a boolean condition)
        if mask is None:
            mask = torch.ones_like(x, dtype=torch.bool)

        # init hidden state using node embedding or the empty state
        if h is None:
            h = self.get_h0(x)
        elif not isinstance(h, list):
            h = [*h]

        # Recurrent loop over time steps
        predictions, states = [], []
        for step in range(steps):
            x_s = x[..., step]
            m_s = mask[..., step]
            h_s = h[-1]
            u_s = u[..., step] if u is not None else None
            # impute missing values with predictions from state
            x_s_hat = self.pred_readout(h_s)
            # store predictions and state
            predictions.append(x_s_hat)
            states.append(torch.stack(h, dim=0))
            # fill missing values in input with prediction
            x_s = torch.where(m_s, x_s, x_s_hat)
            inputs = [x_s, m_s]
            if u_s is not None:
                inputs.append(u_s)
            inputs = torch.cat(inputs, dim=1)  # x filled with x_hat where missing + mask + exogenous
            # update state with original sequence filled using imputations
            h = self.update_state(inputs, h, adj)

        # In autoencoder mode use states after input processing
        if self.autoencoder_mode:
            states = states[1:] + [torch.stack(h, dim=0)]

        # Aggregate outputs -> [batch, features, nodes, steps]
        predictions = torch.stack(predictions, dim=-1)
        states = torch.stack(states, dim=-1)

        return predictions, states
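

# A minimal, dependency-free sketch of the per-step logic in MPGRUImputer.forward for a single
# node and feature:
#   (1) read a prediction out of the current state,
#   (2) fill the input with that prediction where the mask marks the value as missing
#       (torch.where), and
#   (3) update the state with the filled input.
# The exponential-moving-average update and the identity readout are hypothetical stand-ins
# for the GCGRU cells and `pred_readout`. Because the prediction is made before x_s is
# consumed, a prediction is also produced at observed steps.
def toy_recurrent_impute(values, mask, alpha=0.5, h0=0.0):
    h = h0
    predictions, filled = [], []
    for x_s, m_s in zip(values, mask):
        x_hat = h                           # readout from the state before seeing x_s
        predictions.append(x_hat)
        x_in = x_s if m_s else x_hat        # same role as torch.where(m_s, x_s, x_s_hat)
        filled.append(x_in)
        h = (1 - alpha) * h + alpha * x_in  # stand-in for update_state(...)
    return predictions, filled

# Example: toy_recurrent_impute([1.0, 0.0, 3.0], [1, 0, 1]) -> ([0.0, 0.5, 0.5], [1.0, 0.5, 3.0])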


================================================
FILE: lib/nn/layers/rits.py
================================================
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter

from ..utils.ops import reverse_tensor


class FeatureRegression(nn.Module):
    def __init__(self, input_size):
        super(FeatureRegression, self).__init__()
        self.W = Parameter(torch.Tensor(input_size, input_size))
        self.b = Parameter(torch.Tensor(input_size))

        m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size)
        self.register_buffer('m', m)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.shape[0])
        self.W.data.uniform_(-stdv, stdv)
        if self.b is not None:
            self.b.data.uniform_(-stdv, stdv)

    def forward(self, x):
        z_h = F.linear(x, self.W * self.m, self.b)
        return z_h


class TemporalDecay(nn.Module):
    def __init__(self, d_in, d_out, diag=False):
        super(TemporalDecay, self).__init__()
        self.diag = diag
        self.W = Parameter(torch.Tensor(d_out, d_in))
        self.b = Parameter(torch.Tensor(d_out))

        if self.diag:
            assert (d_in == d_out)
            m = torch.eye(d_in, d_in)
            self.register_buffer('m', m)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.shape[0])
        self.W.data.uniform_(-stdv, stdv)
        if self.b is not None:
            self.b.data.uniform_(-stdv, stdv)

    @staticmethod
    def compute_delta(mask, freq=1):
        delta = torch.zeros_like(mask).float()
        one_step = torch.tensor(freq, dtype=delta.dtype, device=delta.device)
        for i in range(1, delta.shape[-2]):
            m = mask[..., i - 1, :]
            delta[..., i, :] = m * one_step + (1 - m) * torch.add(delta[..., i - 1, :], freq)
        return delta

    def forward(self, d):
        if self.diag:
            gamma = F.relu(F.linear(d, self.W * self.m, self.b))
        else:
            gamma = F.relu(F.linear(d, self.W, self.b))
        gamma = torch.exp(-gamma)
        return gamma


class RITS(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size=64):
        super(RITS, self).__init__()
        self.input_size = int(input_size)
        self.hidden_size = int(hidden_size)

        self.rnn_cell = nn.LSTMCell(2 * self.input_size, self.hidden_size)

        self.temp_decay_h = TemporalDecay(d_in=self.input_size, d_out=self.hidden_size, diag=False)
        self.temp_decay_x = TemporalDecay(d_in=self.input_size, d_out=self.input_size, diag=True)

        self.hist_reg = nn.Linear(self.hidden_size, self.input_size)
        self.feat_reg = FeatureRegression(self.input_size)

        self.weight_combine = nn.Linear(2 * self.input_size, self.input_size)

    def init_hidden_states(self, x):
        return torch.zeros((x.shape[0], self.hidden_size), device=x.device)

    def forward(self, x, mask=None, delta=None):
        # x : [batch, steps, features]
        steps = x.shape[-2]

        if mask is None:
            # RITS uses the mask arithmetically (m * x, 1 - m), so default to a float mask
            mask = torch.ones_like(x)
        if delta is None:
            delta = TemporalDecay.compute_delta(mask)

        # init rnn states
        h = self.init_hidden_states(x)
        c = self.init_hidden_states(x)

        imputation = []
        predictions = []
        for step in range(steps):
            d = delta[:, step, :]
            m = mask[:, step, :]
            x_s = x[:, step, :]

            gamma_h = self.temp_decay_h(d)

            # history prediction
            x_h = self.hist_reg(h)
            x_c = m * x_s + (1 - m) * x_h
            h = h * gamma_h

            # feature prediction
            z_h = self.feat_reg(x_c)

            # predictions combination
            gamma_x = self.temp_decay_x(d)
            alpha = self.weight_combine(torch.cat([gamma_x, m], dim=1))
            alpha = torch.sigmoid(alpha)
            c_h = alpha * z_h + (1 - alpha) * x_h

            c_c = m * x_s + (1 - m) * c_h
            inputs = torch.cat([c_c, m], dim=1)
            h, c = self.rnn_cell(inputs, (h, c))

            imputation.append(c_c)
            predictions.append(torch.stack((c_h, z_h, x_h), dim=0))

        # imputation -> [batch, steps, features]
        imputation = torch.stack(imputation, dim=-2)
        # predictions -> [predictions, batch, steps, features]
        predictions = torch.stack(predictions, dim=-2)
        c_h, z_h, x_h = predictions

        return imputation, (c_h, z_h, x_h)


class BRITS(nn.Module):

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.rits_fwd = RITS(input_size, hidden_size)
        self.rits_bwd = RITS(input_size, hidden_size)

    def forward(self, x, mask=None):
        # x: [batches, steps, features]
        # forward
        imp_fwd, pred_fwd = self.rits_fwd(x, mask)
        # backward
        x_bwd = reverse_tensor(x, axis=1)
        mask_bwd = reverse_tensor(mask, axis=1) if mask is not None else None
        imp_bwd, pred_bwd = self.rits_bwd(x_bwd, mask_bwd)
        imp_bwd, pred_bwd = reverse_tensor(imp_bwd, axis=1), [reverse_tensor(pb, axis=1) for pb in pred_bwd]
        # stack into shape = [batch, directions, steps, features]
        imputation = torch.stack([imp_fwd, imp_bwd], dim=1)
        predictions = [torch.stack([pf, pb], dim=1) for pf, pb in zip(pred_fwd, pred_bwd)]
        c_h, z_h, x_h = predictions

        return imputation, (c_h, z_h, x_h)

    @staticmethod
    def consistency_loss(imp_fwd, imp_bwd):
        loss = 0.1 * torch.abs(imp_fwd - imp_bwd).mean()
        return loss
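

# A minimal, dependency-free sketch of TemporalDecay for one feature of one series.
# compute_delta: delta[0] = 0 and, for t >= 1, delta[t] = freq if step t-1 was observed,
# otherwise delta[t-1] + freq. In other words, delta[t] is the time elapsed since the last
# observation. forward: gamma = exp(-relu(w * delta + b)) is a factor in (0, 1] that shrinks
# as the gap grows (for w > 0). The scalar `w` and `b` are illustrative assumptions; the layer
# itself applies a learned linear map over all features. Helper names are hypothetical.
import math


def compute_delta_reference(mask, freq=1):
    delta = [0] * len(mask)
    for t in range(1, len(mask)):
        delta[t] = freq if mask[t - 1] else delta[t - 1] + freq
    return delta


def decay_reference(delta, w=1.0, b=0.0):
    # ReLU followed by exp(-.), as in TemporalDecay.forward
    return [math.exp(-max(0.0, w * d + b)) for d in delta]

# Example: compute_delta_reference([1, 0, 0, 1, 1]) -> [0, 1, 2, 3, 1]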


================================================
FILE: lib/nn/layers/spatial_attention.py
================================================
import torch.nn as nn

from einops import rearrange


class SpatialAttention(nn.Module):
    def __init__(self, d_in, d_model, nheads, dropout=0.):
        super(SpatialAttention, self).__init__()
        self.lin_in = nn.Linear(d_in, d_model)
        self.self_attn = nn.MultiheadAttention(d_model, nheads, dropout=dropout)

    def forward(self, x, att_mask=None, **kwargs):
        r"""Apply multi-head self-attention across the nodes of each (batch, step) pair.

        Args:
            x: input tensor of shape [batch, steps, nodes, d_in] (required).
            att_mask: attention mask over nodes, of shape [nodes, nodes], passed to
                nn.MultiheadAttention as ``attn_mask`` (optional).

        Returns:
            Tensor of shape [batch, steps, nodes, d_model].
        """
        b, s, n, f = x.size()
        x = rearrange(x, 'b s n f -> n (b s) f')
        x = self.lin_in(x)
        x = self.self_attn(x, x, x, attn_mask=att_mask)[0]
        x = rearrange(x, 'n (b s) f -> b s n f', b=b, s=s)
        return x


================================================
FILE: lib/nn/layers/spatial_conv.py
================================================
import torch
from torch import nn

from ... import epsilon


class SpatialConvOrderK(nn.Module):
    """
    Spatial convolution of order K with possibly different diffusion matrices (useful for directed graphs).

    Efficient implementation inspired by the Graph WaveNet codebase.
    """

    def __init__(self, c_in, c_out, support_len=3, order=2, include_self=True):
        super(SpatialConvOrderK, self).__init__()
        self.include_self = include_self
        c_in = (order * support_len + (1 if include_self else 0)) * c_in
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1)
        self.order = order

    @staticmethod
    def compute_support(adj, device=None):
        if device is not None:
            adj = adj.to(device)
        adj_bwd = adj.T
        adj_fwd = adj / (adj.sum(1, keepdim=True) + epsilon)
        adj_bwd = adj_bwd / (adj_bwd.sum(1, keepdim=True) + epsilon)
        support = [adj_fwd, adj_bwd]
        return support

    @staticmethod
    def compute_support_orderK(adj, k, include_self=False, device=None):
        if isinstance(adj, (list, tuple)):
            support = adj
        else:
            support = SpatialConvOrderK.compute_support(adj, device)
        supp_k = []
        for a in support:
            ak = a
            for i in range(k - 1):
                ak = torch.matmul(ak, a.T)
                if not include_self:
                    ak.fill_diagonal_(0.)
                supp_k.append(ak)
        return support + supp_k

    def forward(self, x, support):
        # [batch, features, nodes, steps]
        if x.dim() < 4:
            squeeze = True
            x = torch.unsqueeze(x, -1)
        else:
            squeeze = False
        out = [x] if self.include_self else []
        if not isinstance(support, list):
            support = [support]
        for a in support:
            x1 = torch.einsum('ncvl,wv->ncwl', (x, a)).contiguous()
            out.append(x1)
            for k in range(2, self.order + 1):
                x2 = torch.einsum('ncvl,wv->ncwl', (x1, a)).contiguous()
                out.append(x2)
                x1 = x2

        out = torch.cat(out, dim=1)
        out = self.mlp(out)
        if squeeze:
            out = out.squeeze(-1)
        return out
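

# A minimal, dependency-free sketch of SpatialConvOrderK.compute_support and of the diffusion
# performed in forward. compute_support returns two row-normalized transition matrices, one
# built from adj and one from adj.T, so directed graphs are diffused in both directions.
# The einsum 'ncvl,wv->ncwl' then gives node w the weighted sum sum_v a[w][v] * x[v], applied
# repeatedly for higher orders. The helper names and the value of `eps` (standing in for the
# package-level `epsilon`) are assumptions.
def compute_support_reference(adj, eps=1e-8):
    n = len(adj)
    adj_t = [[adj[j][i] for j in range(n)] for i in range(n)]

    def row_normalize(a):
        # eps avoids division by zero, so rows with no outgoing edges stay all-zero.
        return [[v / (sum(row) + eps) for v in row] for row in a]

    return [row_normalize(adj), row_normalize(adj_t)]


def diffuse_reference(a, x, order=1):
    """Return [A x, A^2 x, ..., A^order x] for a node signal x (one value per node)."""
    out = []
    for _ in range(order):
        # Row w of `a` holds the weights node w uses to aggregate its neighbours.
        x = [sum(a_wv * x_v for a_wv, x_v in zip(row, x)) for row in a]
        out.append(x)
    return out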


================================================
FILE: lib/nn/models/__init__.py
================================================
from .grin import GRINet
from .brits import BRITSNet
from .mpgru import MPGRUNet, BiMPGRUNet
from .var import VARImputer
from .rgain import RGAINNet
from .rnn_imputers import BiRNNImputer, RNNImputer


================================================
FILE: lib/nn/models/brits.py
================================================
import torch
from torch import nn

from ..layers import BRITS


class BRITSNet(nn.Module):
    def __init__(self,
                 d_in,
                 d_hidden=64):
        super(BRITSNet, self).__init__()
        self.birits = BRITS(input_size=d_in,
                            hidden_size=d_hidden)

    def forward(self, x, mask=None, **kwargs):
        # x: [batches, steps, features]
        imputations, predictions = self.birits(x, mask=mask)
        # predictions: [batch, directions, steps, features] x 3
        out = torch.mean(imputations, dim=1)  # -> [batch, steps, features]
        predictions = torch.cat(predictions, dim=1)  # -> [batch, directions * n_predictions, steps, features]
        # reshape
        imputations = torch.transpose(imputations, 0, 1)  # rearrange(imputations, 'b d s f -> d b s f')
        predictions = torch.transpose(predictions, 0, 1)  # rearrange(predictions, 'b d s f -> d b s f')
        return out, imputations, predictions

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument('--d-in', type=int)
        parser.add_argument('--d-hidden', type=int, default=64)
        return parser


================================================
FILE: lib/nn/models/grin.py
================================================
import torch
from einops import rearrange
from torch import nn

from ..layers import BiGRIL
from ...utils.parser_utils import str_to_bool


class GRINet(nn.Module):
    def __init__(self,
                 adj,
                 d_in,
                 d_hidden,
                 d_ff,
                 ff_dropout,
                 n_layers=1,
                 kernel_size=2,
                 decoder_order=1,
                 global_att=False,
                 d_u=0,
                 d_emb=0,
                 layer_norm=False,
                 merge='mlp',
                 impute_only_holes=True):
        super(GRINet, self).__init__()
        self.d_in = d_in
        self.d_hidden = d_hidden
        self.d_u = int(d_u) if d_u is not None else 0
        self.d_emb = int(d_emb) if d_emb is not None else 0
        self.register_buffer('adj', torch.tensor(adj).float())
        self.impute_only_holes = impute_only_holes

        self.bigrill = BiGRIL(input_size=self.d_in,
                              ff_size=d_ff,
                              ff_dropout=ff_dropout,
                              hidden_size=self.d_hidden,
                              embedding_size=self.d_emb,
                              n_nodes=self.adj.shape[0],
                              n_layers=n_layers,
                              kernel_size=kernel_size,
                              decoder_order=decoder_order,
                              global_att=global_att,
                              u_size=self.d_u,
                              layer_norm=layer_norm,
                              merge=merge)

    def forward(self, x, mask=None, u=None, **kwargs):
        # x: [batches, steps, nodes, channels] -> [batches, channels, nodes, steps]
        x = rearrange(x, 'b s n c -> b c n s')
        if mask is not None:
            mask = rearrange(mask, 'b s n c -> b c n s')

        if u is not None:
            u = rearrange(u, 'b s n c -> b c n s')

        # imputation: [batches, channels, nodes, steps] prediction: [4, batches, channels, nodes, steps]
        imputation, prediction = self.bigrill(x, self.adj, mask=mask, u=u, cached_support=self.training)
        # In evaluation stage impute only missing values
        if self.impute_only_holes and not self.training:
            imputation = torch.where(mask, x, imputation)
        # out: [batches, channels, nodes, steps] -> [batches, steps, nodes, channels]
        imputation = torch.transpose(imputation, -3, -1)
        prediction = torch.transpose(prediction, -3, -1)
        if self.training:
            return imputation, prediction
        return imputation

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument('--d-hidden', type=int, default=64)
        parser.add_argument('--d-ff', type=int, default=64)
        parser.add_argument('--ff-dropout', type=float, default=0.)
        parser.add_argument('--n-layers', type=int, default=1)
        parser.add_argument('--kernel-size', type=int, default=2)
        parser.add_argument('--decoder-order', type=int, default=1)
        parser.add_argument('--d-u', type=int, default=0)
        parser.add_argument('--d-emb', type=int, default=8)
        parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False)
        parser.add_argument('--global-att', type=str_to_bool, nargs='?', const=True, default=False)
        parser.add_argument('--merge', type=str, default='mlp')
        parser.add_argument('--impute-only-holes', type=str_to_bool, nargs='?', const=True, default=True)
        return parser


================================================
FILE: lib/nn/models/mpgru.py
================================================
import torch
from einops import rearrange
from torch import nn

from ..layers import MPGRUImputer, SpatialConvOrderK
from ..utils.ops import reverse_tensor
from ...utils.parser_utils import str_to_bool


class MPGRUNet(nn.Module):
    def __init__(self,
                 adj,
                 d_in,
                 d_hidden,
                 d_ff=0,
                 d_u=0,
                 n_layers=1,
                 dropout=0.,
                 kernel_size=2,
                 support_len=2,
                 layer_norm=False,
                 impute_only_holes=True):
        super(MPGRUNet, self).__init__()
        self.register_buffer('adj', torch.tensor(adj).float())
        n_nodes = adj.shape[0]
        self.gcgru = MPGRUImputer(input_size=d_in,
                                  hidden_size=d_hidden,
                                  ff_size=d_ff,
                                  u_size=d_u,
                                  n_layers=n_layers,
                                  dropout=dropout,
                                  kernel_size=kernel_size,
                                  support_len=support_len,
                                  layer_norm=layer_norm,
                                  n_nodes=n_nodes)
        self.impute_only_holes = impute_only_holes

    def forward(self, x, mask=None, u=None, h=None):
        # x: [batches, steps, nodes, channels] -> [batches, channels, nodes, steps]
        x = rearrange(x, 'b s n c -> b c n s')
        if mask is not None:
            mask = rearrange(mask, 'b s n c -> b c n s')
        if u is not None:
            u = rearrange(u, 'b s n c -> b c n s')

        adj = SpatialConvOrderK.compute_support(self.adj, x.device)
        imputation, _ = self.gcgru(x, adj, mask=mask, u=u, h=h)

        # In evaluation stage impute only missing values
        if self.impute_only_holes and not self.training:
            imputation = torch.where(mask, x, imputation)

        # out: [batches, channels, nodes, steps] -> [batches, steps, nodes, channels]
        imputation = rearrange(imputation, 'b c n s -> b s n c')

        return imputation

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument('--d-hidden', type=int, default=64)
        parser.add_argument('--d-ff', type=int, default=64)
        parser.add_argument('--n-layers', type=int, default=1)
        parser.add_argument('--kernel-size', type=int, default=2)
        parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False)
        parser.add_argument('--impute-only-holes', type=str_to_bool, nargs='?', const=True, default=True)
        parser.add_argument('--dropout', type=float, default=0.)
        return parser


class BiMPGRUNet(nn.Module):
    def __init__(self,
                 adj,
                 d_in,
                 d_hidden,
                 d_ff=0,
                 d_u=0,
                 n_layers=1,
                 dropout=0.,
                 kernel_size=2,
                 support_len=2,
                 layer_norm=False,
                 embedding_size=0,
                 merge='mlp',
                 impute_only_holes=True,
                 autoencoder_mode=False):
        super(BiMPGRUNet, self).__init__()
        self.register_buffer('adj', torch.tensor(adj).float())
        n_nodes = adj.shape[0]
        self.gcgru_fwd = MPGRUImputer(input_size=d_in,
                                      hidden_size=d_hidden,
                                      u_size=d_u,
                                      n_layers=n_layers,
                                      dropout=dropout,
                                      kernel_size=kernel_size,
                                      support_len=support_len,
                                      layer_norm=layer_norm,
                                      n_nodes=n_nodes,
                                      autoencoder_mode=autoencoder_mode)
        self.gcgru_bwd = MPGRUImputer(input_size=d_in,
                                      hidden_size=d_hidden,
                                      u_size=d_u,
                                      n_layers=n_layers,
                                      dropout=dropout,
                                      kernel_size=kernel_size,
                                      support_len=support_len,
                                      layer_norm=layer_norm,
                                      n_nodes=n_nodes,
                                      autoencoder_mode=autoencoder_mode)
        self.impute_only_holes = impute_only_holes

        if n_nodes is None:
            embedding_size = 0
        if embedding_size > 0:
            self.emb = nn.Parameter(torch.empty(embedding_size, n_nodes))
            nn.init.kaiming_normal_(self.emb, nonlinearity='relu')
        else:
            self.register_parameter('emb', None)

        if merge == 'mlp':
            self._impute_from_states = True
            self.out = nn.Sequential(
                nn.Conv2d(in_channels=2 * d_hidden + d_in + embedding_size,
                          out_channels=d_ff, kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=d_ff, out_channels=d_in, kernel_size=1)
            )
        elif merge in ['mean', 'sum', 'min', 'max']:
            self._impute_from_states = False
            self.out = getattr(torch, merge)
        else:
            raise ValueError("Merge option %s not allowed." % merge)

    def forward(self, x, mask=None, u=None, h=None):
        # x: [batches, steps, nodes, channels] -> [batches, channels, nodes, steps]
        x = rearrange(x, 'b s n c -> b c n s')
        if mask is not None:
            mask = rearrange(mask, 'b s n c -> b c n s')
        if u is not None:
            u = rearrange(u, 'b s n c -> b c n s')

        adj = SpatialConvOrderK.compute_support(self.adj, x.device)

        # Forward
        fwd_pred, fwd_states = self.gcgru_fwd(x, adj, mask=mask, u=u)
        # Backward
        rev_x, rev_mask, rev_u = [reverse_tensor(tens, axis=-1) for tens in (x, mask, u)]
        bwd_res = self.gcgru_bwd(rev_x, adj, mask=rev_mask, u=rev_u)
        bwd_pred, bwd_states = [reverse_tensor(res, axis=-1) for res in bwd_res]

        if self._impute_from_states:
            inputs = [fwd_states[-1], bwd_states[-1], mask]  # take only state of last gcgru layer
            if self.emb is not None:
                b, *_, s = x.shape  # x: [batches, channels, nodes, steps]
                inputs += [self.emb.view(1, *self.emb.shape, 1).expand(b, -1, -1, s)]  # stack emb for batches and steps
            imputation = torch.cat(inputs, dim=1)
            imputation = self.out(imputation)
        else:
            imputation = torch.stack([fwd_pred, bwd_pred], dim=1)
            imputation = self.out(imputation, dim=1)

        # In evaluation stage impute only missing values
        if self.impute_only_holes and not self.training:
            imputation = torch.where(mask, x, imputation)

        # out: [batches, channels, nodes, steps] -> [batches, steps, nodes, channels]
        imputation = rearrange(imputation, 'b c n s -> b s n c')

        return imputation

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument('--d-hidden', type=int, default=64)
        parser.add_argument('--d-ff', type=int, default=64)
        parser.add_argument('--n-layers', type=int, default=1)
        parser.add_argument('--kernel-size', type=int, default=2)
        parser.add_argument('--d-emb', type=int, default=8)
        parser.add_argument('--layer-norm', type=str_to_bool, nargs='?', const=True, default=False)
        parser.add_argument('--merge', type=str, default='mlp')
        parser.add_argument('--impute-only-holes', type=str_to_bool, nargs='?', const=True, default=True)
        parser.add_argument('--dropout', type=float, default=0.)
        parser.add_argument('--autoencoder-mode', type=str_to_bool, nargs='?', const=True, default=False)
        return parser


================================================
FILE: lib/nn/models/rgain.py
================================================
import torch
from torch import nn

from .rnn_imputers import BiRNNImputer
from ...utils.parser_utils import str_to_bool


class Generator(nn.Module):
    def __init__(self, d_in, d_model, d_z, dropout=0., inject_noise=True):
        super(Generator, self).__init__()
        self.inject_noise = inject_noise
        self.d_z = d_z if inject_noise else 0
        self.birnn = BiRNNImputer(d_in,
                                  d_model,
                                  d_u=d_z,
                                  concat_mask=True,
                                  detach_inputs=False,
                                  dropout=dropout,
                                  state_init='zero')

    def forward(self, x, mask):
        if self.inject_noise:
            z = torch.rand(x.size(0), x.size(1), self.d_z, device=x.device) * 0.1
        else:
            z = None
        return self.birnn(x, mask, u=z)


class Discriminator(torch.nn.Module):
    def __init__(self, d_in, d_model, dropout=0.):
        super(Discriminator, self).__init__()
        self.birnn = nn.GRU(2 * d_in, d_model, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.read_out = nn.Linear(2 * d_model, d_in)

    def forward(self, x, h):
        x_in = torch.cat([x, h], dim=-1)
        out, _ = self.birnn(x_in)
        logits = self.read_out(self.dropout(out))
        return logits


class RGAINNet(torch.nn.Module):
    def __init__(self, d_in, d_model, d_z, dropout=0., inject_noise=False, k=5):
        super(RGAINNet, self).__init__()
        self.inject_noise = inject_noise
        self.k = k
        self.generator = Generator(d_in, d_model, d_z=d_z, dropout=dropout, inject_noise=inject_noise)
        self.discriminator = Discriminator(d_in, d_model, dropout)

    def forward(self, x, mask, **kwargs):
        if not self.training and self.inject_noise:
            res = []
            for _ in range(self.k):
                res.append(self.generator(x, mask)[0])
            # trailing comma: return a 1-tuple, matching the tuple returned by the generator
            return torch.stack(res, 0).mean(0),

        return self.generator(x, mask)

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument('--d-in', type=int)
        parser.add_argument('--d-model', type=int, default=None)
        parser.add_argument('--d-z', type=int, default=8)
        parser.add_argument('--k', type=int, default=5)
        parser.add_argument('--inject-noise', type=str_to_bool, nargs='?', const=True, default=False)
        parser.add_argument('--dropout', type=float, default=0.)
        return parser
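

# Sketch (NumPy, illustrative only): when `inject_noise` is on, RGAINNet.forward
# averages `k` stochastic generator outputs at evaluation time. It returns a 1-tuple
# (note the trailing comma) so callers can keep indexing `[0]`, as with the
# generator's tuple output. `demo_average_k_samples` and the noisy `generate`
# callable are hypothetical stand-ins for the generator.
import numpy as np


def demo_average_k_samples(generate, k=5):
    # generate() returns a tuple whose first element is an imputation
    samples = [generate()[0] for _ in range(k)]
    return np.stack(samples, 0).mean(0),  # 1-tuple, like RGAINNet.forward

# If the k draws are independent, their average has roughly 1/k of the variance
# of a single noisy imputation.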


================================================
FILE: lib/nn/models/rnn_imputers.py
================================================
import torch
from torch import nn

from ..utils.ops import reverse_tensor


class RNNImputer(nn.Module):
    """Fill the blanks with a 1-step-ahead GRU predictor."""

    def __init__(self, d_in, d_model, concat_mask=True, detach_inputs=False, state_init='zero', d_u=0):
        super(RNNImputer, self).__init__()
        self.concat_mask = concat_mask
        self.detach_inputs = detach_inputs
        self.state_init = state_init
        self.d_model = d_model
        self.input_dim = d_in + d_u if not concat_mask else 2 * d_in + d_u
        self.rnn_cell = nn.GRUCell(self.input_dim, d_model)
        self.read_out = nn.Linear(d_model, d_in)

    def init_hidden_state(self, x):
        if self.state_init == 'zero':
            return torch.zeros((x.size(0), self.d_model), device=x.device, dtype=x.dtype)
        if self.state_init == 'noise':
            return torch.randn(x.size(0), self.d_model, device=x.device, dtype=x.dtype)
        raise ValueError(f"state_init must be 'zero' or 'noise', got {self.state_init!r}")

    def _preprocess_input(self, x, x_hat, m, u):
        if self.detach_inputs:
            x_p = torch.where(m, x, x_hat.detach())
        else:
            x_p = torch.where(m, x, x_hat)

        if u is not None:
            x_p = torch.cat([x_p, u], -1)
        if self.concat_mask:
            x_p = torch.cat([x_p, m], -1)
        return x_p

    def forward(self, x, mask, u=None, return_hidden=False):
        # x: [batches, steps, features]
        steps = x.size(1)
        # ensure masked values are not visible
        x = torch.where(mask, x, torch.zeros_like(x))

        h = self.init_hidden_state(x)
        x_hat = self.read_out(h)
        hs = [h]
        preds = [x_hat]
        for s in range(steps - 1):
            u_t = None if u is None else u[:, s]
            x_t = self._preprocess_input(x[:, s], x_hat, mask[:, s], u_t)
            h = self.rnn_cell(x_t, h)
            x_hat = self.read_out(h)
            hs.append(h)
            preds.append(x_hat)

        x_hat = torch.stack(preds, 1)
        h = torch.stack(hs, 1)
        if return_hidden:
            return x_hat, h
        return x_hat

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument('--d-in', type=int)
        parser.add_argument('--d-model', type=int, default=None)
        return parser
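

# Sketch (NumPy, single sequence, no exogenous `u`): RNNImputer emits the prediction
# for step s from the state that has seen only steps < s. At each step it feeds the
# observed value where mask is True and its own prediction otherwise, concatenated
# with the mask (concat_mask=True). `cell`, `read_out` and `h0` are hypothetical
# stand-ins for nn.GRUCell, nn.Linear and the initial state.
import numpy as np


def demo_one_step_ahead(x, mask, cell, read_out, h0):
    """x, mask: [steps, features]; returns predictions shaped [steps, features]."""
    x = np.where(mask, x, 0.)  # masked values must not be visible
    h = h0
    x_hat = read_out(h)
    preds = [x_hat]
    for s in range(x.shape[0] - 1):
        x_p = np.where(mask[s], x[s], x_hat)  # fill the hole with the prediction
        inp = np.concatenate([x_p, mask[s].astype(x.dtype)])
        h = cell(inp, h)
        x_hat = read_out(h)
        preds.append(x_hat)
    return np.stack(preds, 0)

# With a toy "persistence" cell that keeps the last filled value, the hole at step 2
# of [1, 2, ?, 4] is predicted as 2.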


class BiRNNImputer(nn.Module):
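    """Fill the blanks with two 1-step-ahead GRU predictors, one running forward and one backward in time,
    whose hidden states are merged by a linear read-out."""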

    def __init__(self, d_in, d_model, dropout=0., concat_mask=True, detach_inputs=False, state_init='zero', d_u=0):
        super(BiRNNImputer, self).__init__()
        self.d_model = d_model
        self.fwd_rnn = RNNImputer(d_in, d_model, concat_mask, detach_inputs=detach_inputs, state_init=state_init,
                                  d_u=d_u)
        self.bwd_rnn = RNNImputer(d_in, d_model, concat_mask, detach_inputs=detach_inputs, state_init=state_init,
                                  d_u=d_u)
        self.dropout = nn.Dropout(dropout)
        self.read_out = nn.Linear(2 * d_model, d_in)

    def forward(self, x, mask, u=None, return_hidden=False):
        # x: [batches, steps, features]
        x_hat_fwd, h_fwd = self.fwd_rnn(x, mask, u=u, return_hidden=True)
        x_hat_bwd, h_bwd = self.bwd_rnn(reverse_tensor(x, 1),
                                        reverse_tensor(mask, 1),
                                        u=reverse_tensor(u, 1) if u is not None else None,
                                        return_hidden=True)
        x_hat_bwd = reverse_tensor(x_hat_bwd, 1)
        h_bwd = reverse_tensor(h_bwd, 1)
        h = self.dropout(torch.cat([h_fwd, h_bwd], -1))
        x_hat = self.read_out(h)
        if return_hidden:
            return (x_hat, x_hat_fwd, x_hat_bwd), h
        return x_hat, x_hat_fwd, x_hat_bwd

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument('--d-in', type=int)
        parser.add_argument('--d-model', type=int, default=None)
        parser.add_argument('--dropout', type=float, default=0.)
        return parser
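

# Sketch (NumPy, illustrative only): BiRNNImputer runs the backward imputer on the
# time-reversed inputs and reverses its outputs again. As a result, index t of the
# backward output refers to the same time step as index t of the forward output. In
# RNNImputer the state at index t has seen only steps before t (forward) or after t
# (backward), so neither direction sees x_t itself. Here np.cumsum is a hypothetical
# stand-in for a causal pass; unlike RNNImputer, it includes the current step.
import numpy as np


def demo_backward_pass(x, causal_fn, axis=0):
    # reverse time, run the causal pass, then restore the original time order
    return np.flip(causal_fn(np.flip(x, axis=axis)), axis=axis)

# For x = [1, 2, 3, 4] the forward pass gives prefix sums [1, 3, 6, 10] and the
# re-aligned backward pass gives suffix sums [10, 9, 7, 4].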


================================================
FILE: lib/nn/models/var.py
================================================
import torch
from einops import rearrange
from torch import nn

from lib import epsilon


class VAR(nn.Module):
    def __init__(self, order, d_in, d_out=None, steps_ahead=1, bias=True):
        super(VAR, self).__init__()
        self.order = order
        self.d_in = d_in
        self.d_out = d_out if d_out is not None else d_in
        self.steps_ahead = steps_ahead
        self.lin = nn.Linear(order * d_in, steps_ahead * self.d_out, bias=bias)

    def forward(self, x):
        # x: [batches, steps, features]
        x = rearrange(x, 'b s f -> b (s f)')
        out = self.lin(x)
        out = rearrange(out, 'b (s f) -> b s f', s=self.steps_ahead, f=self.d_out)
        return out

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument('--order', type=int)
        parser.add_argument('--d-in', type=int)
        parser.add_argument('--d-out', type=int, default=None)
        parser.add_argument('--steps-ahead', type=int, default=1)
        return parser
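

# Sketch (NumPy, illustrative only): VAR.forward flattens the `order` lagged steps in
# step-major order ('b s f -> b (s f)', the same layout as a C-order reshape). It then
# applies one affine map and unflattens the output to [batch, steps_ahead, d_out].
# `demo_var_forward` and its explicit `weight`/`bias` arrays are hypothetical
# stand-ins for the nn.Linear parameters.
import numpy as np


def demo_var_forward(x, weight, bias, steps_ahead, d_out):
    """x: [batch, order, d_in]; weight: [steps_ahead * d_out, order * d_in]."""
    batch = x.shape[0]
    flat = x.reshape(batch, -1)  # 'b s f -> b (s f)'
    out = flat @ weight.T + bias  # what nn.Linear computes
    return out.reshape(batch, steps_ahead, d_out)  # 'b (s f) -> b s f'

# An AR(2) model on one feature with weights [0.5, 0.5] predicts the mean of the
# last two observations: [2, 4] -> 3.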


class VARImputer(nn.Module):
    """Fill the blanks with a 1-step-ahead VAR predictor."""

    def __init__(self, order, d_in, padding='mean'):
        super(VARImputer, self).__init__()
        assert padding in ['mean', 'zero']
        self.order = order
        self.padding = padding
        self.predictor = VAR(order, d_in, d_out=d_in, steps_ahead=1)

    def forward(self, x, mask=None):
        # x: [batches, steps, features]
        batch_size, steps, n_feats = x.shape
        if mask is None:
            # boolean mask: torch.where requires a bool condition
            mask = torch.ones_like(x, dtype=torch.bool)
        x = x * mask
        # pad input sequence to start filling from first step
        if self.padding == 'mean':
            mean = torch.sum(x, 1) / (torch.sum(mask, 1) + epsilon)
            pad = torch.repeat_interleave(mean.unsqueeze(1), self.order, 1)
        elif self.padding == 'zero':
            pad = torch.zeros((batch_size, self.order, n_feats)).to(x.device)
        x = torch.cat([pad, x], 1)
        # x: [batch, order + steps, features]
        x = [x[:, i] for i in range(x.shape[1])]
        for s in range(steps):
            x_hat = self.predictor(torch.stack(x[s:s + self.order], 1))
            x_hat = x_hat[:, 0]
            x[s + self.order] = torch.where(mask[:, s], x[s + self.order], x_hat)
        x = torch.stack(x[self.order:], 1)  # remove padding
        return x

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument('--order', type=int)
        parser.add_argument('--d-in', type=int)
        parser.add_argument("--padding", type=str, default='mean')
        return parser
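

# Sketch (NumPy, single sequence): VARImputer pads the series with `order` copies of
# the per-feature mean of the observed values (padding='mean'). It then walks forward,
# replacing each missing entry with the one-step-ahead prediction from the previous
# `order` (already filled) steps. As a simplifying assumption, the predictor here is
# a hypothetical array of scalar AR coefficients `coef` (oldest lag first) shared by
# all features, not the full nn.Linear of VAR.
import numpy as np


def demo_var_impute(x, mask, coef, eps=1e-8):
    """x, mask: [steps, features]; coef: np.ndarray of shape [order]."""
    order = len(coef)
    x = np.where(mask, x, 0.)
    mean = x.sum(0) / (mask.sum(0) + eps)  # mean of the observed values only
    seq = list(np.repeat(mean[None], order, 0)) + list(x)
    for s in range(x.shape[0]):
        window = np.stack(seq[s:s + order], 0)  # [order, features]
        x_hat = (coef[:, None] * window).sum(0)
        seq[s + order] = np.where(mask[s], seq[s + order], x_hat)
    return np.stack(seq[order:], 0)  # drop the padding

# With AR weights [0.5, 0.5], the gap in [1, 3, ?, 5] is filled with 2.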


================================================
FILE: lib/nn/utils/__init__.py
================================================


================================================
FILE: lib/nn/utils/metric_base.py
================================================
from functools import partial

import torch
from torchmetrics import Metric
from torchmetrics.utilities.checks import _check_same_shape


class MaskedMetric(Metric):
    def __init__(self,
                 metric_fn,
                 mask_nans=False,
                 mask_inf=False,
                 compute_on_step=True,
                 dist_sync_on_step=False,
                 process_group=None,
                 dist_sync_fn=None,
                 metric_kwargs=None,
                 at=None):
        super(MaskedMetric, self).__init__(compute_on_step=compute_on_step,
                                           dist_sync_on_step=dist_sync_on_step,
                                           process_group=process_group,
                                           dist_sync_fn=dist_sync_fn)

        if metric_kwargs is None:
            metric_kwargs = dict()
        self.metric_fn = partial(metric_fn, **metric_kwargs)
        self.mask_nans = mask_nans
        self.mask_inf = mask_inf
        if at is None:
            self.at = slice(None)
        else:
            self.at = slice(at, at + 1)
        self.add_state('value', dist_reduce_fx='sum', default=torch.tensor(0.).float())
        self.add_state('numel', dist_reduce_fx='sum', default=torch.tensor(0))

    def _check_mask(self, mask, val):
        if mask is None:
            mask = torch.ones_like(val, dtype=torch.bool)
        else:
            _check_same_shape(mask, val)
        if self.mask_nans:
            mask = mask * ~torch.isnan(val)
        if self.mask_inf:
            mask = mask * ~torch.isinf(val)
        return mask

    def _compute_masked(self, y_hat, y, mask):
        _check_same_shape(y_hat, y)
        val = self.metric_fn(y_hat, y)
        mask = self._check_mask(mask, val)
        val = torch.where(mask, val, torch.tensor(0., device=val.device).float())
        return val.sum(), mask.sum()

    def _compute_std(self, y_hat, y):
        _check_same_shape(y_hat, y)
        val = self.metric_fn(y_hat, y)
        return val.sum(), val.numel()

    def is_masked(self, mask):
        return self.mask_inf or self.mask_nans or (mask is not None)

    def update(self, y_hat, y, mask=None):
        y_hat = y_hat[:, self.at]
        y = y[:, self.at]
        if mask is not None:
            mask = mask[:, self.at]
        if self.is_masked(mask):
            val, numel = self._compute_masked(y_hat, y, mask)
        else:
            val, numel = self._compute_std(y_hat, y)
        self.value += val
        self.numel += numel

    def compute(self):
        if self.numel > 0:
            return self.value / self.numel
        return self.value
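

# Sketch (pure Python, illustrative only): MaskedMetric accumulates the sum of the
# element-wise metric and the count of valid entries across update() calls and
# divides once in compute(). The result is the masked mean over all batches. This
# differs from the average of per-batch means when batches hold different numbers of
# valid entries. `DemoMaskedMAE` is a hypothetical, dependency-free stand-in for
# MaskedMAE on flat lists.
class DemoMaskedMAE:
    def __init__(self):
        self.value = 0.
        self.numel = 0

    def update(self, y_hat, y, mask):
        for p, t, m in zip(y_hat, y, mask):
            if m:  # only entries with mask == 1 contribute
                self.value += abs(p - t)
                self.numel += 1

    def compute(self):
        return self.value / self.numel if self.numel > 0 else self.value

# Example: errors [1, 2] in a first batch and [4] (plus one masked-out entry) in a
# second give 7 / 3, whereas averaging the two batch means would give 2.75.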


================================================
FILE: lib/nn/utils/metrics.py
================================================
from .metric_base import MaskedMetric
from .ops import mape
from torch.nn import functional as F
import torch

from torchmetrics.utilities.checks import _check_same_shape

from ... import epsilon


class MaskedMAE(MaskedMetric):
    def __init__(self,
                 mask_nans=False,
                 mask_inf=False,
                 compute_on_step=True,
                 dist_sync_on_step=False,
                 process_group=None,
                 dist_sync_fn=None,
                 at=None):
        super(MaskedMAE, self).__init__(metric_fn=F.l1_loss,
                                        mask_nans=mask_nans,
                                        mask_inf=mask_inf,
                                        compute_on_step=compute_on_step,
                                        dist_sync_on_step=dist_sync_on_step,
                                        process_group=process_group,
                                        dist_sync_fn=dist_sync_fn,
                                        metric_kwargs={'reduction': 'none'},
                                        at=at)


class MaskedMAPE(MaskedMetric):
    def __init__(self,
                 mask_nans=False,
                 compute_on_step=True,
                 dist_sync_on_step=False,
                 process_group=None,
                 dist_sync_fn=None,
                 at=None):
        super(MaskedMAPE, self).__init__(metric_fn=mape,
                                         mask_nans=mask_nans,
                                         mask_inf=True,
                                         compute_on_step=compute_on_step,
                                         dist_sync_on_step=dist_sync_on_step,
                                         process_group=process_group,
                                         dist_sync_fn=dist_sync_fn,
                                         at=at)


class MaskedMSE(MaskedMetric):
    def __init__(self,
                 mask_nans=False,
                 compute_on_step=True,
                 dist_sync_on_step=False,
                 process_group=None,
                 dist_sync_fn=None,
                 at=None):
        super(MaskedMSE, self).__init__(metric_fn=F.mse_loss,
                                        mask_nans=mask_nans,
                                        mask_inf=True,
                                        compute_on_step=compute_on_step,
                                        dist_sync_on_step=dist_sync_on_step,
                                        process_group=process_group,
                                        dist_sync_fn=dist_sync_fn,
                                        metric_kwargs={'reduction': 'none'},
                                        at=at)


class MaskedMRE(MaskedMetric):
    def __init__(self,
                 mask_nans=False,
                 mask_inf=False,
                 compute_on_step=True,
                 dist_sync_on_step=False,
                 process_group=None,
                 dist_sync_fn=None,
                 at=None):
        super(MaskedMRE, self).__init__(metric_fn=F.l1_loss,
                                        mask_nans=mask_nans,
                                        mask_inf=mask_inf,
                                        compute_on_step=compute_on_step,
                                        dist_sync_on_step=dist_sync_on_step,
                                        process_group=process_group,
                                        dist_sync_fn=dist_sync_fn,
                                        metric_kwargs={'reduction': 'none'},
                                        at=at)
        self.add_state('tot', dist_reduce_fx='sum', default=torch.tensor(0., dtype=torch.float))

    def _compute_masked(self, y_hat, y, mask):
        _check_same_shape(y_hat, y)
        val = self.metric_fn(y_hat, y)
        mask = self._check_mask(mask, val)
        val = torch.where(mask, val, torch.tensor(0., device=y.device, dtype=torch.float))
        y_masked = torch.where(mask, y, torch.tensor(0., device=y.device, dtype=torch.float))
        return val.sum(), mask.sum(), y_masked.sum()

    def _compute_std(self, y_hat, y):
        _check_same_shape(y_hat, y)
        val = self.metric_fn(y_hat, y)
        return val.sum(), val.numel(), y.sum()

    def compute(self):
        if self.tot > epsilon:
            return self.value / self.tot
        return self.value

    def update(self, y_hat, y, mask=None):
        y_hat = y_hat[:, self.at]
        y = y[:, self.at]
        if mask is not None:
            mask = mask[:, self.at]
        if self.is_masked(mask):
            val, numel, tot = self._compute_masked(y_hat, y, mask)
        else:
            val, numel, tot = self._compute_std(y_hat, y)
        self.value += val
        self.numel += numel
        self.tot += tot




================================================
FILE: lib/nn/utils/ops.py
================================================
import torch
import torch.nn.functional as F
from einops import reduce
from torch.autograd import Variable

from ... import epsilon


def mae(y_hat, y, reduction='none'):
    return F.l1_loss(y_hat, y, reduction=reduction)


def mape(y_hat, y):
    return torch.abs((y_hat - y) / y)


def wape_loss(y_hat, y):
    l = torch.abs(y_hat - y)
    return l.sum() / (y.sum() + epsilon)


def smape_loss(y_hat, y):
    c = torch.abs(y) > epsilon
    l_minus = torch.abs(y_hat - y)
    l_plus = torch.abs(y_hat + y) + epsilon
    l = 2 * l_minus / l_plus * c.float()
    return l.sum() / c.sum()


def peak_prediction_loss(y_hat, y, reduction='none'):
    y_max = reduce(y, 'b s n 1 -> b 1 n 1', 'max')
    y_min = reduce(y, 'b s n 1 -> b 1 n 1', 'min')
    target = torch.cat([y_max, y_min], dim=1)
    return F.mse_loss(y_hat, target, reduction=reduction)


def wrap_loss_fn(base_loss):
    def loss_fn(y_hat, y_true, mask=None):
        scaling = 1.
        if mask is not None:
            try:
                loss = base_loss(y_hat, y_true, reduction='none')
            except TypeError:
                loss = base_loss(y_hat, y_true)
            loss = loss * mask
            loss = loss.sum() / (mask.sum() + epsilon)
            # scaling = mask.sum() / torch.numel(mask)
        else:
            loss = base_loss(y_hat, y_true).mean()
        return scaling * loss

    return loss_fn


def rbf_sim(x, gamma, device='cpu'):
    n = x.size()[0]
    a = torch.exp(-gamma * F.pdist(x, 2) ** 2)
    row_idx, col_idx = torch.triu_indices(n, n, 1)
    A = 0.5 * torch.eye(n, n).to(device)
    A[row_idx, col_idx] = a
    return A + A.T
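

# Sketch (NumPy, illustrative only): rbf_sim fills the strict upper triangle of A with
# exp(-gamma * ||x_i - x_j||^2). torch.pdist returns the pairs in the same row-major
# order as triu_indices. The function then puts 0.5 on the diagonal and returns
# A + A.T, a symmetric matrix with ones on the diagonal. Both functions below are
# hypothetical NumPy re-expressions used to check that equivalence.
import numpy as np


def demo_rbf_sim_triu(x, gamma):
    n = x.shape[0]
    rows, cols = np.triu_indices(n, 1)
    sq_dists = ((x[rows] - x[cols]) ** 2).sum(-1)
    A = 0.5 * np.eye(n)
    A[rows, cols] = np.exp(-gamma * sq_dists)
    return A + A.T


def demo_rbf_sim_dense(x, gamma):
    sq_dists = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)
    return np.exp(-gamma * sq_dists)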


def reverse_tensor(tensor=None, axis=-1):
    if tensor is None:
        return None
    if tensor.dim() <= 1:
        return tensor
    # reverse the order of the entries along `axis`
    return torch.flip(tensor, dims=(axis,))


================================================
FILE: lib/utils/__init__.py
================================================
from .utils import *


================================================
FILE: lib/utils/numpy_metrics.py
================================================
import numpy as np


def mae(y_hat, y):
    return np.abs(y_hat - y).mean()


def nmae(y_hat, y):
    delta = np.max(y) - np.min(y) + 1e-8
    return mae(y_hat, y) * 100 / delta


def mape(y_hat, y):
    return 100 * np.abs((y_hat - y) / (y + 1e-8)).mean()


def mse(y_hat, y):
    return np.square(y_hat - y).mean()


def rmse(y_hat, y):
    return np.sqrt(mse(y_hat, y))


def nrmse(y_hat, y):
    delta = np.max(y) - np.min(y) + 1e-8
    return rmse(y_hat, y) * 100 / delta


def nrmse_2(y_hat, y):
    nrmse_ = np.sqrt(np.square(y_hat - y).sum() / np.square(y).sum())
    return nrmse_ * 100


def r2(y_hat, y):
    return 1. - np.square(y_hat - y).sum() / (np.square(y.mean(0) - y).sum())


def masked_mae(y_hat, y, mask):
    err = np.abs(y_hat - y) * mask
    return err.sum() / mask.sum()


def masked_mape(y_hat, y, mask):
    err = np.abs((y_hat - y) / (y + 1e-8)) * mask
    return err.sum() / mask.sum()


def masked_mse(y_hat, y, mask):
    err = np.square(y_hat - y) * mask
    return err.sum() / mask.sum()


def masked_rmse(y_hat, y, mask):
    err = np.square(y_hat - y) * mask
    return np.sqrt(err.sum() / mask.sum())


def masked_mre(y_hat, y, mask):
    err = np.abs(y_hat - y) * mask
    return err.sum() / ((y * mask).sum() + 1e-8)


================================================
FILE: lib/utils/parser_utils.py
================================================
import inspect
from argparse import Namespace, ArgumentParser
from typing import Union


def str_to_bool(value):
    if isinstance(value, bool):
        return value
    if value.lower() in {'false', 'f', '0', 'no', 'n', 'off'}:
        return False
    elif value.lower() in {'true', 't', '1', 'yes', 'y', 'on'}:
        return True
    raise ValueError(f'{value} is not a valid boolean value')
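

# Sketch (standard library only): the boolean flags declared in the models'
# add_model_specific_args use type=str_to_bool, nargs='?', const=True. With this
# setup, `--flag` alone means True, `--flag no` parses the string, and omitting the
# flag keeps the default. `demo_str_to_bool` is a local copy of str_to_bool so the
# snippet is self-contained; `--layer-norm` mirrors one such flag.
import argparse


def demo_str_to_bool(value):
    if isinstance(value, bool):
        return value
    if value.lower() in {'false', 'f', '0', 'no', 'n', 'off'}:
        return False
    if value.lower() in {'true', 't', '1', 'yes', 'y', 'on'}:
        return True
    raise ValueError(f'{value} is not a valid boolean value')


def demo_bool_flag_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--layer-norm', type=demo_str_to_bool, nargs='?', const=True, default=False)
    return parser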


def config_dict_from_args(args):
    """
    Extract a dictionary with the experiment configuration from arguments (necessary to filter TestTube arguments)

    :param args: TTNamespace
    :return: hyparams dict
    """
    keys_to_remove = {'hpc_exp_number', 'trials', 'optimize_parallel', 'optimize_parallel_gpu',
                      'optimize_parallel_cpu', 'generate_trials', 'optimize_trials_parallel_gpu'}
    hparams = {key: v for key, v in args.__dict__.items() if key not in keys_to_remove}
    return hparams


def update_from_config(args: Namespace, config: dict):
    assert set(config.keys()) <= set(vars(args)), f'{set(config.keys()).difference(vars(args))} not in args.'
    args.__dict__.update(config)
    return args


def parse_by_group(parser):
    """
    Create a nested namespace using the groups defined in the argument parser.
    Adapted from https://stackoverflow.com/a/56631542/6524027

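    :param parser: the parser
    :return: a Namespace holding the flat arguments under `flat`, the positional/optional arguments as attributes,
        and one nested Namespace per additional argument group, keyed by the group title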
    """
    assert isinstance(parser, ArgumentParser)
    args = parser.parse_args()

    # the first two argument groups are 'positional_arguments' and 'optional_arguments'
    pos_group, optional_group = parser._action_groups[0], parser._action_groups[1]
    args_dict = args._get_kwargs()
    pos_optional_arg_names = [arg.dest for arg in pos_group._group_actions] + [arg.dest for arg in
                                                                               optional_group._group_actions]
    pos_optional_args = {name: value for name, value in args_dict if name in pos_optional_arg_names}
    other_group_args = dict()

    # If there are additional argument groups, add them as nested namespaces
    if len(parser._action_groups) > 2:
        for group in parser._action_groups[2:]:
            group_arg_names = [arg.dest for arg in group._group_actions]
            other_group_args[group.title] = Namespace(
                **{name: value for name, value in args_dict if name in group_arg_names})

    # combine the positional/optional args and the group args
    combined_args = pos_optional_args
    combined_args.update(other_group_args)
    return Namespace(flat=args, **combined_args)


def filter_args(args: Union[Namespace, dict], target_cls, return_dict=False):
    argspec = inspect.getfullargspec(target_cls.__init__)
    target_args = argspec.args
    if isinstance(args, Namespace):
        args = vars(args)
    filtered_args = {k: args[k] for k in target_args if k in args}
    if return_dict:
        return filtered_args
    return Namespace(**filtered_args)


def filter_function_args(args: Union[Namespace, dict], function, return_dict=False):
    argspec = inspect.getfullargspec(function)
    target_args = argspec.args
    if isinstance(args, Namespace):
        args = vars(args)
    filtered_args = {k: args[k] for k in target_args if k in args}
    if return_dict:
        return filtered_args
    return Namespace(**filtered_args)
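

# Sketch (standard library only): filter_args keeps the entries of a flat
# configuration whose names match the positional-or-keyword parameters of
# target_cls.__init__, as reported by inspect.getfullargspec(...).args. `DemoModel`
# and `demo_filter_args` are hypothetical. Keyword-only parameters (declared after
# `*`) are listed in `.kwonlyargs`, not `.args`, so this filtering silently drops them.
import inspect


class DemoModel:
    def __init__(self, d_in, d_model, dropout=0., *, seed=0):
        self.d_in, self.d_model, self.dropout, self.seed = d_in, d_model, dropout, seed


def demo_filter_args(args, target_cls):
    target_args = inspect.getfullargspec(target_cls.__init__).args
    return {k: args[k] for k in target_args if k in args}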


================================================
FILE: lib/utils/utils.py
================================================
import numpy as np
import pandas as pd

from sklearn.metrics.pairwise import haversine_distances


def sample_mask(shape, p=0.002, p_noise=0., max_seq=1, min_seq=1, rng=None):
    if rng is None:
        rand = np.random.random
        randint = np.random.randint
    else:
        rand = rng.random
        randint = rng.integers
    mask = rand(shape) < p
    for col in range(mask.shape[1]):
        idxs = np.flatnonzero(mask[:, col])
        if not len(idxs):
            continue
        fault_len = min_seq
        if max_seq > min_seq:
            fault_len = fault_len + int(randint(max_seq - min_seq))
        idxs_ext = np.concatenate([np.arange(i, i + fault_len) for i in idxs])
        idxs = np.unique(idxs_ext)
        idxs = np.clip(idxs, 0, shape[0] - 1)
        mask[idxs, col] = True
    mask = mask | (rand(mask.shape) < p_noise)
    return mask.astype('uint8')
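

# Sketch (NumPy, single column): after drawing fault start points with probability
# `p`, sample_mask extends each start into a run of `fault_len` consecutive steps.
# One length is drawn per column: min_seq plus an integer in [0, max_seq - min_seq)
# when max_seq > min_seq, otherwise exactly min_seq. The indices are then clipped to
# the series length. Clipping happens after np.unique, so repeated last indices can
# appear; this is harmless because they only set mask entries to True.
# `_demo_extend_faults` is a hypothetical helper isolating that step.
import numpy as np


def _demo_extend_faults(starts, fault_len, n_steps):
    # assumes `starts` is non-empty (sample_mask skips columns without faults)
    idxs = np.concatenate([np.arange(i, i + fault_len) for i in starts])
    return np.clip(np.unique(idxs), 0, n_steps - 1)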


def compute_mean(x, index=None):
    """Compute a mean value for each datetime, e.g. to fill missing entries. Means are computed hierarchically: first
    the hourly mean within each week of each year; remaining NaNs are filled with the hourly mean within each month of
    each year, then with the hourly mean of the same month across all years, and finally with the mean of the same
    hour over the whole dataset. Any NaN still left is filled by forward and then backward filling."""
    if isinstance(x, np.ndarray) and index is not None:
        shape = x.shape
        x = x.reshape((shape[0], -1))
        df_mean = pd.DataFrame(x, index=index)
    else:
        df_mean = x.copy()
    cond0 = [df_mean.index.year, df_mean.index.isocalendar().week, df_mean.index.hour]
    cond1 = [df_mean.index.year, df_mean.index.month, df_mean.index.hour]
    conditions = [cond0, cond1, cond1[1:], cond1[2:]]
    while df_mean.isna().values.sum() and len(conditions):
        nan_mean = df_mean.groupby(conditions[0]).transform(np.nanmean)
        df_mean = df_mean.fillna(nan_mean)
        conditions = conditions[1:]
    if df_mean.isna().values.sum():
        df_mean = df_mean.ffill()
        df_mean = df_mean.bfill()
    if isinstance(x, np.ndarray):
        df_mean = df_mean.values.reshape(shape)
    return df_mean


def geographical_distance(x=None, to_rad=True):
    """
    Compute the as-the-crow-flies (great-circle) distance between every pair of samples in `x`. The first coordinate
    of each point is assumed to be the latitude, the second the longitude. The input is assumed to be in degrees; if
    it is already in radians, `to_rad` must be set to False. `x` must be two-dimensional.

    Parameters
    ----------
    x : pd.DataFrame or np.ndarray
        array_like structure of shape (n_samples, 2).
    to_rad : bool
        whether to convert inputs to radians (provided that they are in degrees).

    Returns
    -------
    distances : pd.DataFrame or np.ndarray
        Matrix of shape (n_samples, n_samples) with the pairwise distances in kilometers (a DataFrame indexed by
        `x.index` if `x` is a DataFrame).
    """
    _AVG_EARTH_RADIUS_KM = 6371.0088

    # Extract values of X if it is a DataFrame, else assume it is 2-dim array of lat-lon pairs
    latlon_pairs = x.values if isinstance(x, pd.DataFrame) else x

    # If the input values are in degrees, convert them to radians
    if to_rad:
        latlon_pairs = np.radians(latlon_pairs)

    distances = haversine_distances(latlon_pairs) * _AVG_EARTH_RADIUS_KM

    # Cast response
    if isinstance(x, pd.DataFrame):
        res = pd.DataFrame(distances, x.index, x.index)
    else:
        res = distances

    return res
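

# Sketch (standard library only): sklearn's haversine_distances expects
# [latitude, longitude] pairs in radians and returns great-circle distances on the
# unit sphere; multiplying by the mean Earth radius gives kilometres. The
# hypothetical helper below evaluates the same haversine formula for one pair of
# points given in degrees.
import math


def _demo_haversine_km(lat1, lon1, lat2, lon2, radius_km=6371.0088):
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    return 2 * radius_km * math.asin(math.sqrt(a))

# One degree of longitude along the equator is about 111.2 km.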


def infer_mask(df, infer_from='next'):
    """Infer evaluation mask from DataFrame. In the evaluation mask a value is 1 if it is present in the DataFrame and
    absent in the `infer_from` month.

    @param pd.DataFrame df: the DataFrame.
    @param str infer_from: denotes from which month the evaluation value must be inferred.
    Can be either `previous` or `next`.
    @return: pd.DataFrame eval_mask: the evaluation mask for the DataFrame
    """
    mask = (~df.isna()).astype('uint8')
    eval_mask = pd.DataFrame(index=mask.index, columns=mask.columns, data=0).astype('uint8')
    if infer_from == 'previous':
        offset = -1
    elif infer_from == 'next':
        offset = 1
    else:
        raise ValueError('infer_from can only be one of %s' % ['previous', 'next'])
    months = sorted(set(zip(mask.index.year, mask.index.month)))
    length = len(months)
    for i in range(length):
        j = (i + offset) % length
        year_i, month_i = months[i]
        year_j, month_j = months[j]
        mask_j = mask[(mask.index.year == year_j) & (mask.index.month == month_j)]
        mask_i = mask_j.shift(1, pd.DateOffset(months=12 * (year_i - year_j) + (month_i - month_j)))
        mask_i = mask_i[~mask_i.index.duplicated(keep='first')]
        mask_i = mask_i[np.in1d(mask_i.index, mask.index)]
        eval_mask.loc[mask_i.index] = ~mask_i.loc[mask_i.index] & mask.loc[mask_i.index]
    return eval_mask


def prediction_dataframe(y, index, columns=None, aggregate_by='mean'):
    """Aggregate batched predictions in a single DataFrame.

    @param (list or np.ndarray) y: the list of predictions.
    @param (list or np.ndarray) index: the list of time indexes coupled with the predictions.
    @param (list or pd.Index) columns: the columns of the returned DataFrame.
    @param (str or list) aggregate_by: how to aggregate the predictions in case there are more than one for a step.
    - `mean`: take the mean of the predictions
    - `central`: take the prediction at the central position, assuming that the predictions are ordered chronologically
    - `smooth_central`: average the predictions weighted by a gaussian signal with std=1
    - `last`: take the last prediction
    @return: pd.DataFrame df: the aggregated predictions (a list of DataFrames, one per method, if `aggregate_by` is a
    list)
    """
    dfs = [pd.DataFrame(data=data.reshape(data.shape[:2]), index=idx, columns=columns) for data, idx in zip(y, index)]
    df = pd.concat(dfs)
    preds_by_step = df.groupby(df.index)
    # aggregate according passed methods
    aggr_methods = ensure_list(aggregate_by)
    dfs = []
    for aggr_by in aggr_methods:
        if aggr_by == 'mean':
            dfs.append(preds_by_step.mean())
        elif aggr_by == 'central':
            dfs.append(preds_by_step.aggregate(lambda x: x.iloc[len(x) // 2]))
        elif aggr_by == 'smooth_central':
            from scipy.signal.windows import gaussian
            dfs.append(preds_by_step.aggregate(lambda x: np.average(x, weights=gaussian(len(x), 1))))
        elif aggr_by == 'last':
            dfs.append(preds_by_step.aggregate(lambda x: x.iloc[0]))  # first imputation has missing value in last position
        else:
            raise ValueError('aggregate_by can only be one of %s' % ['mean', 'central', 'smooth_central', 'last'])
    if isinstance(aggregate_by, str):
        return dfs[0]
    return dfs


def ensure_list(obj):
    if isinstance(obj, (list, tuple)):
        return list(obj)
    else:
        return [obj]


def missing_val_lens(mask):
    m = np.concatenate([np.zeros((1, mask.shape[1])),
                        (~mask.astype('bool')).astype('int'),
                        np.zeros((1, mask.shape[1]))])
    mdiff = np.diff(m, axis=0)
    lens = []
    for c in range(m.shape[1]):
        mj, = mdiff[:, c].nonzero()
        diff = np.diff(mj)[::2]
        lens.extend(list(diff))
    return lens
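

# Sketch (NumPy, one column): missing_val_lens inverts the validity mask (1 =
# observed), pads it with a 0 at both ends and differentiates it. Every run of
# missing values then produces a +1 at its first index and a -1 just after its last
# one. Consecutive non-zero positions therefore pair up as (start, end), and taking
# every other difference gives the run lengths. `_demo_missing_run_lengths` is a
# hypothetical helper that takes the already inverted (1 = missing) column.
import numpy as np


def _demo_missing_run_lengths(missing):
    m = np.concatenate([[0], np.asarray(missing, dtype=int), [0]])
    edges, = np.diff(m).nonzero()
    return np.diff(edges)[::2].tolist()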


def disjoint_months(dataset, months=None, synch_mode='window'):
    idxs = np.arange(len(dataset))
    months = ensure_list(months)
    # divide indices according to window or horizon
    if synch_mode == 'window':
        start, end = 0, dataset.window - 1
    elif synch_mode == 'horizon':
        start, end = dataset.horizon_offset, dataset.horizon_offset + dataset.horizon - 1
    else:
        raise ValueError('synch_mode can only be one of %s' % ['window', 'horizon'])
    # windows that start and end inside the given months
    start_in_months = np.isin(dataset.index[dataset._indices + start].month, months)
    end_in_months = np.isin(dataset.index[dataset._indices + end].month, months)
    idxs_in_months = start_in_months & end_in_months
    after_idxs = idxs[idxs_in_months]
    # windows that start and end inside the complementary months
    months = np.setdiff1d(np.arange(1, 13), months)
    start_in_months = np.isin(dataset.index[dataset._indices + start].month, months)
    end_in_months = np.isin(dataset.index[dataset._indices + end].month, months)
    idxs_in_months = start_in_months & end_in_months
    prev_idxs = idxs[idxs_in_months]
    return prev_idxs, after_idxs


def thresholded_gaussian_kernel(x, theta=None, threshold=None, threshold_on_input=False):
    if theta is None:
        theta = np.std(x)
    weights = np.exp(-np.square(x / theta))
    if threshold is not None:
        mask = x > threshold if threshold_on_input else weights < threshold
        weights[mask] = 0.
    return weights

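`prediction_dataframe` stitches overlapping sliding-window predictions back onto one timeline: every timestamp is predicted once per window that covers it, so the per-window DataFrames are concatenated, grouped by index, and reduced. A minimal pandas-only sketch of the `mean` and `central` strategies, assuming the windows are listed in chronological order; the function name `aggregate_windows` and the toy values are hypothetical, for illustration only:

```python
import numpy as np
import pandas as pd


def aggregate_windows(windows, indices, columns, how="mean"):
    """Merge overlapping window predictions that share timestamps (hypothetical helper)."""
    # One DataFrame per window, indexed by the timestamps that window covers.
    frames = [pd.DataFrame(w, index=idx, columns=columns) for w, idx in zip(windows, indices)]
    grouped = pd.concat(frames).groupby(level=0)
    if how == "mean":
        return grouped.mean()
    if how == "central":
        # Positional access: pick the middle copy among the predictions of a timestamp.
        return grouped.aggregate(lambda s: s.iloc[len(s) // 2])
    raise ValueError(f"unknown aggregation {how!r}")


# Two windows of length 3 with stride 1, covering timestamps 0..3.
windows = [np.array([[1.0], [2.0], [3.0]]), np.array([[4.0], [5.0], [6.0]])]
indices = [pd.Index([0, 1, 2]), pd.Index([1, 2, 3])]
print(aggregate_windows(windows, indices, ["sensor"], how="mean"))
```

Timestamps 1 and 2 are covered by both windows, so `mean` averages the two copies while `central` keeps the later one (index `len(s) // 2 == 1`).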

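`missing_val_lens` measures how long each gap of missing values is. It pads the inverted mask with a row of zeros above and below, so `np.diff` along time gives +1 where a gap starts and -1 right after it ends; the distance between each start/end pair is the gap length. A compact sketch of the same idea with explicit start/end arrays (the name `gap_lengths` and the toy mask are hypothetical; 1 = observed, 0 = missing):

```python
import numpy as np


def gap_lengths(mask):
    """Return the length of every run of missing values, column by column."""
    missing = (~mask.astype(bool)).astype(int)
    pad = np.zeros((1, mask.shape[1]), dtype=int)
    # Padding guarantees that gaps touching the first or last row still get a start and an end.
    edges = np.diff(np.concatenate([pad, missing, pad]), axis=0)
    lens = []
    for c in range(mask.shape[1]):
        starts = np.flatnonzero(edges[:, c] == 1)
        ends = np.flatnonzero(edges[:, c] == -1)
        lens.extend((ends - starts).tolist())
    return lens


mask = np.array([[1, 0],
                 [0, 0],
                 [0, 1],
                 [1, 0]])
print(gap_lengths(mask))  # column 0 has one gap of 2, column 1 has gaps of 2 and 1
```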
================================================
FILE: requirements.txt
================================================
einops
fancyimpute==0.6
h5py
openpyxl
numpy
pandas
pytorch-lightning==1.4
pyyaml
scikit-learn
scipy
tables
tensorboard
tensorflow==2.5.0
tensorflow-gpu==2.4.0
torch==1.8
torchvision
torchaudio
torchmetrics==0.5


================================================
FILE: scripts/run_baselines.py
================================================
from argparse import ArgumentParser

import numpy as np
from fancyimpute import MatrixFactorization, IterativeImputer
from sklearn.neighbors import kneighbors_graph

from lib import datasets
from lib.utils import numpy_metrics
from lib.utils.parser_utils import str_to_bool

metrics = {
    'mae': numpy_metrics.masked_mae,
    'mse': numpy_metrics.masked_mse,
    'mre': numpy_metrics.masked_mre,
    'mape': numpy_metrics.masked_mape
}


def parse_args():
    parser = ArgumentParser()
    # experiment setting
    parser.add_argument('--datasets', nargs='+', type=str, default=['all'])
    parser.add_argument('--imputers', nargs='+', type=str, default=['all'])
    parser.add_argument('--n-runs', type=int, default=5)
    parser.add_argument('--in-sample', type=str_to_bool, nargs='?', const=True, default=True)
    # SpatialKNNImputer params
    parser.add_argument('--k', type=int, default=10)
    # MFImputer params
    parser.add_argument('--rank', type=int, default=10)
    # MICEImputer params
    parser.add_argument('--mice-iterations', type=int, default=100)
    parser.add_argument('--mice-n-features', type=int, default=None)
    args = parser.parse_args()
    # parse dataset
    if args.datasets[0] == 'all':
        args.datasets = ['air36', 'air', 'bay', 'la', 'bay_noise', 'la_noise']
    # parse imputers
    if args.imputers[0] == 'all':
        args.imputers = ['mean', 'knn', 'mf', 'mice']
    if not args.in_sample:
        args.imputers = [name for name in args.imputers if name in ['mean', 'mice']]
    return args


class Imputer:
    short_name: str

    def __init__(self, method=None, is_deterministic=True, in_sample=True):
        self.name = self.__class__.__name__
        self.method = method
        self.is_deterministic = is_deterministic
        self.in_sample = in_sample

    def fit(self, x, mask):
        if not self.in_sample:
            x_hat = np.where(mask, x, np.nan)
            return self.method.fit(x_hat)

    def predict(self, x, mask):
        x_hat = np.where(mask, x, np.nan)
        if self.in_sample:
            return self.method.fit_transform(x_hat)
        else:
            return self.method.transform(x_hat)

    def params(self):
        return dict()


class SpatialKNNImputer(Imputer):
    short_name = 'knn'

    def __init__(self, adj, k=20):
        super(SpatialKNNImputer, self).__init__()
        self.k = k
        # min-max normalize the similarity to [0, 1]
        sim = (adj - adj.min()) / (adj.max() - adj.min())
        knns = kneighbors_graph(1 - sim,
                                n_neighbors=self.k,
                                include_self=False,
                                metric='precomputed').toarray()
        self.knns = knns

    def fit(self, x, mask):
        pass

    def predict(self, x, mask):
        x = np.where(mask, x, 0)
        with np.errstate(divide='ignore', invalid='ignore'):
            y_hat = (x @ self.knns.T) / (mask @ self.knns.T)
        y_hat[~np.isfinite(y_hat)] = x.mean()
        return np.where(mask, x, y_hat)

    def params(self):
        return dict(k=self.k)


class MeanImputer(Imputer):
    short_name = 'mean'

    def fit(self, x, mask):
        d = np.where(mask, x, np.nan)
        self.means = np.nanmean(d, axis=0, keepdims=True)

    def predict(self, x, mask):
        if self.in_sample:
            d = np.where(mask, x, np.nan)
            means = np.nanmean(d, axis=0, keepdims=True)
        else:
            means = self.means
        return np.where(mask, x, means)


class MatrixFactorizationImputer(Imputer):
    short_name = 'mf'

    def __init__(self, rank=10, loss='mae', verbose=0):
        method = MatrixFactorization(rank=rank, loss=loss, verbose=verbose)
        super(MatrixFactorizationImputer, self).__init__(method, is_deterministic=False, in_sample=True)

    def params(self):
        return dict(rank=self.method.rank)


class MICEImputer(Imputer):
    short_name = 'mice'

    def __init__(self, max_iter=100, n_nearest_features=None, in_sample=True, verbose=False):
        method = IterativeImputer(max_iter=max_iter, n_nearest_features=n_nearest_features, verbose=verbose)
        is_deterministic = n_nearest_features is None
        super(MICEImputer, self).__init__(method, is_deterministic=is_deterministic, in_sample=in_sample)

    def params(self):
        return dict(max_iter=self.method.max_iter, k=self.method.n_nearest_features or -1)


def get_dataset(dataset_name):
    if dataset_name[:3] == 'air':
        dataset = datasets.AirQuality(impute_nans=True, small=dataset_name[3:] == '36')
    elif dataset_name == 'bay':
        dataset = datasets.MissingValuesPemsBay()
    elif dataset_name == 'la':
        dataset = datasets.MissingValuesMetrLA()
    elif dataset_name == 'la_noise':
        dataset = datasets.MissingValuesMetrLA(p_fault=0., p_noise=0.25)
    elif dataset_name == 'bay_noise':
        dataset = datasets.MissingValuesPemsBay(p_fault=0., p_noise=0.25)
    else:
        raise ValueError(f"Dataset {dataset_name} not available in this setting.")

    # split in train/test
    if isinstance(dataset, datasets.AirQuality):
        test_slice = np.in1d(dataset.df.index.month, dataset.test_months)
        train_slice = ~test_slice
    else:
        train_slice = np.zeros(len(dataset)).astype(bool)
        train_slice[:-int(0.2 * len(dataset))] = True

    # integrate back eval values in dataset
    dataset.eval_mask[train_slice] = 0

    return dataset, train_slice


def get_imputer(imputer_name, args):
    if imputer_name == 'mean':
        imputer = MeanImputer(in_sample=args.in_sample)
    elif imputer_name == 'knn':
        imputer = SpatialKNNImputer(adj=args.adj, k=args.k)
    elif imputer_name == 'mf':
        imputer = MatrixFactorizationImputer(rank=args.rank)
    elif imputer_name == 'mice':
        imputer = MICEImputer(max_iter=args.mice_iterations,
                              n_nearest_features=args.mice_n_features,
                              in_sample=args.in_sample)
    else:
        raise ValueError(f"Imputer {imputer_name} not available in this setting.")
    return imputer


def run(imputer, dataset, train_slice):
    test_slice = ~train_slice
    if args.in_sample:
        x_train, mask_train = dataset.numpy(), dataset.training_mask
        y_hat = imputer.predict(x_train, mask_train)[test_slice]
    else:
        x_train, mask_train = dataset.numpy()[train_slice], dataset.training_mask[train_slice]
        imputer.fit(x_train, mask_train)
        x_test, mask_test = dataset.numpy()[test_slice], dataset.training_mask[test_slice]
        y_hat = imputer.predict(x_test, mask_test)

    # Evaluate model
    y_true = dataset.numpy()[test_slice]
    eval_mask = dataset.eval_mask[test_slice]

    for metric, metric_fn in metrics.items():
        error = metric_fn(y_hat, y_true, eval_mask)
        print(f'{imputer.name} on {ds_name} {metric}: {error:.4f}')


if __name__ == '__main__':

    args = parse_args()
    print(args.__dict__)

    for ds_name in args.datasets:

        dataset, train_slice = get_dataset(ds_name)
        args.adj = dataset.get_similarity(thr=0.1)

        # Instantiate imputers
        imputers = [get_imputer(name, args) for name in args.imputers]

        for imputer in imputers:
            n_runs = 1 if imputer.is_deterministic else args.n_runs
            for _ in range(n_runs):
                run(imputer, dataset, train_slice)

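`SpatialKNNImputer` fills each missing reading with the average of the currently observed readings of its k most similar sensors. With X zeroed where missing, M the observation mask and K the 0/1 neighbor matrix, the estimate is (X Kᵀ) / (M Kᵀ), falling back to the mean of X when no neighbor is observed. A toy sketch that skips the `kneighbors_graph` step and uses a hand-written neighbor matrix (k = 2, 3 sensors); the function name and data are hypothetical:

```python
import numpy as np


def knn_impute(x, mask, knns):
    """Average observed neighbors; knns[i, j] = 1 if sensor j is a neighbor of sensor i."""
    x = np.where(mask, x, 0.0)
    mask = mask.astype(float)
    with np.errstate(divide="ignore", invalid="ignore"):
        # Numerator: sum of observed neighbor values; denominator: number of observed neighbors.
        y_hat = (x @ knns.T) / (mask @ knns.T)
    # Same fallback as the script: global mean of the zero-filled data.
    y_hat[~np.isfinite(y_hat)] = x.mean()
    # Keep observed values, use the estimate only where data is missing.
    return np.where(mask.astype(bool), x, y_hat)


knns = np.array([[0, 1, 1],
                 [1, 0, 1],
                 [1, 1, 0]], dtype=float)
x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
mask = np.array([[0, 1, 1],
                 [1, 0, 0]], dtype=bool)
print(knn_impute(x, mask, knns))
```

In the first row, sensor 0 gets the mean of sensors 1 and 2; in the second row, sensors 1 and 2 each see only sensor 0 observed, so both copy its value.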

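`run` scores each imputer only on the entries flagged by `eval_mask`, so values that were observed in the input do not count toward the error. `numpy_metrics.masked_mae` lives in `lib/utils/numpy_metrics.py`, which is not shown here; the sketch below assumes it averages absolute errors over the masked entries and is not the repository's implementation (the name `masked_mae_sketch` is hypothetical):

```python
import numpy as np


def masked_mae_sketch(y_hat, y, mask):
    """Mean absolute error restricted to entries where mask is true (assumed semantics)."""
    mask = mask.astype(bool)
    return np.abs(y_hat[mask] - y[mask]).mean()


y_true = np.array([[1.0, 2.0], [3.0, 4.0]])
y_hat = np.array([[1.5, 2.0], [0.0, 4.0]])
eval_mask = np.array([[1, 0], [1, 0]])
print(masked_mae_sketch(y_hat, y_true, eval_mask))
```

Only the first column is evaluated here, giving (0.5 + 3.0) / 2.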
================================================
FILE: scripts/run_imputation.py
================================================
import copy
import datetime
import os
import pathlib
from argparse import ArgumentParser

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import yaml
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.optim.lr_scheduler import CosineAnnealingLR

from lib import fillers, datasets, config
from lib.data.datamodule import SpatioTemporalDataModule
from lib.data.imputation_dataset import ImputationDataset, GraphImputationDataset
from lib.nn import models
from lib.nn.utils.metric_base import MaskedMetric
from lib.nn.utils.metrics import MaskedMAE, MaskedMAPE, MaskedMSE, MaskedMRE
from lib.utils import parser_utils, numpy_metrics, ensure_list, prediction_dataframe
from lib.utils.parser_utils import str_to_bool


def has_graph_support(model_cls):
    return model_cls in [models.GRINet, models.MPGRUNet, models.BiMPGRUNet]


def get_model_classes(model_str):
    if model_str == 'brits':
        model, filler = models.BRITSNet, fillers.BRITSFiller
    elif model_str == 'grin':
        model, filler = models.GRINet, fillers.GraphFiller
    elif model_str == 'mpgru':
        model, filler = models.MPGRUNet, fillers.GraphFiller
    elif model_str == 'bimpgru':
        model, filler = models.BiMPGRUNet, fillers.GraphFiller
    elif model_str == 'var':
        model, filler = models.VARImputer, fillers.Filler
    elif model_str == 'gain':
        model, filler = models.RGAINNet, fillers.RGAINFiller
    elif model_str == 'birnn':
        model, filler = models.BiRNNImputer, fillers.MultiImputationFiller
    elif model_str == 'rnn':
        model, filler = models.RNNImputer, fillers.Filler
    else:
        raise ValueError(f'Model {model_str} not available.')
    return model, filler


def get_dataset(dataset_name):
    if dataset_name[:3] == 'air':
        dataset = datasets.AirQ
SYMBOL INDEX (371 symbols across 37 files)

FILE: lib/data/datamodule/spatiotemporal.py
  class SpatioTemporalDataModule (line 10) | class SpatioTemporalDataModule(pl.LightningDataModule):
    method __init__ (line 15) | def __init__(self, dataset: TemporalDataset,
    method is_spatial (line 43) | def is_spatial(self):
    method n_nodes (line 47) | def n_nodes(self):
    method d_in (line 53) | def d_in(self):
    method d_out (line 59) | def d_out(self):
    method train_slice (line 65) | def train_slice(self):
    method val_slice (line 69) | def val_slice(self):
    method test_slice (line 73) | def test_slice(self):
    method get_scaling_axes (line 76) | def get_scaling_axes(self, dim='global'):
    method get_scaler (line 93) | def get_scaler(self):
    method setup (line 101) | def setup(self, stage=None):
    method _data_loader (line 117) | def _data_loader(self, dataset, shuffle=False, batch_size=None, **kwar...
    method train_dataloader (line 125) | def train_dataloader(self, shuffle=True, batch_size=None):
    method val_dataloader (line 131) | def val_dataloader(self, shuffle=False, batch_size=None):
    method test_dataloader (line 134) | def test_dataloader(self, shuffle=False, batch_size=None):
    method add_argparse_args (line 138) | def add_argparse_args(parser, **kwargs):

FILE: lib/data/imputation_dataset.py
  class ImputationDataset (line 7) | class ImputationDataset(TemporalDataset):
    method __init__ (line 9) | def __init__(self, data,
    method get (line 37) | def get(self, item, preprocess=False):
  class GraphImputationDataset (line 43) | class GraphImputationDataset(ImputationDataset, SpatioTemporalDataset):

FILE: lib/data/preprocessing/scalers.py
  class AbstractScaler (line 5) | class AbstractScaler(ABC):
    method __init__ (line 7) | def __init__(self, **kwargs):
    method __repr__ (line 11) | def __repr__(self):
    method __call__ (line 15) | def __call__(self, *args, **kwargs):
    method params (line 18) | def params(self):
    method fit (line 22) | def fit(self, x):
    method transform (line 26) | def transform(self, x):
    method inverse_transform (line 30) | def inverse_transform(self, x):
    method fit_transform (line 33) | def fit_transform(self, x):
    method to_torch (line 37) | def to_torch(self):
  class Scaler (line 47) | class Scaler(AbstractScaler):
    method __init__ (line 48) | def __init__(self, offset=0., scale=1.):
    method params (line 53) | def params(self):
    method fit (line 56) | def fit(self, x, mask=None, keepdims=True):
    method transform (line 59) | def transform(self, x):
    method inverse_transform (line 62) | def inverse_transform(self, x):
    method fit_transform (line 65) | def fit_transform(self, x, mask=None, keepdims=True):
  class StandardScaler (line 70) | class StandardScaler(Scaler):
    method __init__ (line 71) | def __init__(self, axis=0):
    method fit (line 75) | def fit(self, x, mask=None, keepdims=True):
  class MinMaxScaler (line 86) | class MinMaxScaler(Scaler):
    method __init__ (line 87) | def __init__(self, axis=0):
    method fit (line 91) | def fit(self, x, mask=None, keepdims=True):

FILE: lib/data/spatiotemporal_dataset.py
  class SpatioTemporalDataset (line 8) | class SpatioTemporalDataset(TemporalDataset):
    method __init__ (line 9) | def __init__(self, data,
    method __repr__ (line 54) | def __repr__(self):
    method n_nodes (line 58) | def n_nodes(self):
    method check_dim (line 62) | def check_dim(data):
    method dataframe (line 73) | def dataframe(self):

FILE: lib/data/temporal_dataset.py
  class TemporalDataset (line 11) | class TemporalDataset(Dataset):
    method __init__ (line 12) | def __init__(self, data,
    method __getitem__ (line 78) | def __getitem__(self, item):
    method __contains__ (line 81) | def __contains__(self, item):
    method __len__ (line 84) | def __len__(self):
    method __repr__ (line 87) | def __repr__(self):
    method data (line 93) | def data(self):
    method data (line 97) | def data(self, value):
    method trend (line 102) | def trend(self):
    method trend (line 106) | def trend(self, value):
    method add_exogenous (line 111) | def add_exogenous(self, obj, name, for_window=True, for_horizon=False):
    method horizon_offset (line 131) | def horizon_offset(self):
    method sample_span (line 135) | def sample_span(self):
    method preprocess (line 139) | def preprocess(self):
    method n_steps (line 143) | def n_steps(self):
    method n_channels (line 147) | def n_channels(self):
    method indices (line 151) | def indices(self):
    method exo_window_keys (line 157) | def exo_window_keys(self):
    method exo_horizon_keys (line 161) | def exo_horizon_keys(self):
    method exo_common_keys (line 165) | def exo_common_keys(self):
    method signature (line 169) | def signature(self):
    method get (line 189) | def get(self, item, preprocess=False):
    method snapshot (line 214) | def snapshot(self, indices=None, preprocess=True):
    method expand_indices (line 231) | def expand_indices(self, indices=None, unique=False, merge=False):
    method overlapping_indices (line 247) | def overlapping_indices(self, idxs1, idxs2, synch_mode='window', as_ma...
    method data_timestamps (line 259) | def data_timestamps(self, indices=None, flatten=True):
    method reduce_dataset (line 266) | def reduce_dataset(self, indices, inplace=False):
    method check_input (line 286) | def check_input(self, data):
    method check_dim (line 301) | def check_dim(data):
    method dataframe (line 308) | def dataframe(self):
    method add_argparse_args (line 312) | def add_argparse_args(parser, **kwargs):

FILE: lib/datasets/air_quality.py
  class AirQuality (line 11) | class AirQuality(PandasDataset):
    method __init__ (line 14) | def __init__(self, impute_nans=False, small=False, freq='60T', masked_...
    method load_raw (line 27) | def load_raw(self, small=False):
    method load (line 38) | def load(self, impute_nans=True, small=False, masked_sensors=None):
    method splitter (line 58) | def splitter(self, dataset, val_len=1., in_sample=False, window=0):
    method get_similarity (line 80) | def get_similarity(self, thr=0.1, include_self=False, force_symmetric=...
    method mask (line 93) | def mask(self):
    method training_mask (line 97) | def training_mask(self):
    method test_interval_mask (line 100) | def test_interval_mask(self, dtype=bool, squeeze=True):

FILE: lib/datasets/metr_la.py
  class MetrLA (line 11) | class MetrLA(PandasDataset):
    method __init__ (line 12) | def __init__(self, impute_zeros=False, freq='5T'):
    method load (line 18) | def load(self, impute_zeros=True):
    method load_distance_matrix (line 33) | def load_distance_matrix(self):
    method get_similarity (line 54) | def get_similarity(self, thr=0.1, force_symmetric=False, sparse=False):
    method mask (line 68) | def mask(self):
  class MissingValuesMetrLA (line 72) | class MissingValuesMetrLA(MetrLA):
    method __init__ (line 75) | def __init__(self, p_fault=0.0015, p_noise=0.05):
    method training_mask (line 89) | def training_mask(self):
    method splitter (line 92) | def splitter(self, dataset, val_len=0, test_len=0, window=0):

FILE: lib/datasets/pd_dataset.py
  class PandasDataset (line 6) | class PandasDataset:
    method __init__ (line 7) | def __init__(self, dataframe: pd.DataFrame, u: pd.DataFrame = None, na...
    method __repr__ (line 52) | def __repr__(self):
    method has_mask (line 56) | def has_mask(self):
    method has_u (line 60) | def has_u(self):
    method resample_ (line 63) | def resample_(self, freq, aggr):
    method dataframe (line 84) | def dataframe(self) -> pd.DataFrame:
    method length (line 88) | def length(self):
    method n_nodes (line 92) | def n_nodes(self):
    method mask (line 96) | def mask(self):
    method numpy (line 101) | def numpy(self, return_idx=False):
    method pytorch (line 106) | def pytorch(self):
    method __len__ (line 110) | def __len__(self):
    method build (line 114) | def build():
    method load_raw (line 117) | def load_raw(self):
    method load (line 120) | def load(self):

FILE: lib/datasets/pems_bay.py
  class PemsBay (line 11) | class PemsBay(PandasDataset):
    method __init__ (line 12) | def __init__(self):
    method load (line 17) | def load(self, impute_zeros=True):
    method load_distance_matrix (line 28) | def load_distance_matrix(self, ids):
    method get_similarity (line 47) | def get_similarity(self, type='dcrnn', thr=0.1, force_symmetric=False,...
    method mask (line 76) | def mask(self):
  class MissingValuesPemsBay (line 82) | class MissingValuesPemsBay(PemsBay):
    method __init__ (line 85) | def __init__(self, p_fault=0.0015, p_noise=0.05):
    method training_mask (line 99) | def training_mask(self):
    method splitter (line 102) | def splitter(self, dataset, val_len=0, test_len=0, window=0):

FILE: lib/datasets/synthetic.py
  function generate_mask (line 11) | def generate_mask(shape, p_block=0.01, p_point=0.01, max_seq=1, min_seq=...
  class SyntheticDataset (line 42) | class SyntheticDataset(Dataset):
    method __init__ (line 45) | def __init__(self, filename,
    method __len__ (line 87) | def __len__(self):
    method __getitem__ (line 90) | def __getitem__(self, index):
    method n_channels (line 106) | def n_channels(self):
    method n_nodes (line 110) | def n_nodes(self):
    method n_exogenous (line 114) | def n_exogenous(self):
    method training_mask (line 118) | def training_mask(self):
    method load (line 122) | def load(filename):
    method get_similarity (line 125) | def get_similarity(self, sparse=False):
    method split (line 130) | def split(self, val_len=0, test_len=0):
    method train_dataloader (line 143) | def train_dataloader(self, shuffle=True, batch_size=32):
    method val_dataloader (line 146) | def val_dataloader(self, shuffle=False, batch_size=32):
    method test_dataloader (line 149) | def test_dataloader(self, shuffle=False, batch_size=32):
  class ChargedParticles (line 153) | class ChargedParticles(SyntheticDataset):
    method __init__ (line 155) | def __init__(self, static_adj=False,
    method __getitem__ (line 180) | def __getitem__(self, item):
    method get_similarity (line 191) | def get_similarity(self, sparse=False):
    method n_exogenous (line 195) | def n_exogenous(self):

FILE: lib/fillers/britsfiller.py
  class BRITSFiller (line 7) | class BRITSFiller(Filler):
    method training_step (line 9) | def training_step(self, batch, batch_idx):
    method validation_step (line 40) | def validation_step(self, batch, batch_idx):
    method test_step (line 67) | def test_step(self, batch, batch_idx):

FILE: lib/fillers/filler.py
  class Filler (line 15) | class Filler(pl.LightningModule):
    method __init__ (line 16) | def __init__(self,
    method reset_model (line 70) | def reset_model(self):
    method trainable_parameters (line 74) | def trainable_parameters(self):
    method forward (line 78) | def forward(self, *args, **kwargs):
    method _check_metric (line 82) | def _check_metric(metric, on_step=False):
    method _set_metrics (line 91) | def _set_metrics(self, metrics):
    method _preprocess (line 97) | def _preprocess(self, data, batch_preprocessing):
    method _postprocess (line 112) | def _postprocess(self, data, batch_preprocessing):
    method predict_batch (line 127) | def predict_batch(self, batch, preprocess=False, postprocess=True, ret...
    method predict_loader (line 157) | def predict_loader(self, loader, preprocess=False, postprocess=True, r...
    method _unpack_batch (line 191) | def _unpack_batch(self, batch):
    method training_step (line 205) | def training_step(self, batch, batch_idx):
    method validation_step (line 236) | def validation_step(self, batch, batch_idx):
    method test_step (line 263) | def test_step(self, batch, batch_idx):
    method on_train_epoch_start (line 280) | def on_train_epoch_start(self) -> None:
    method configure_optimizers (line 286) | def configure_optimizers(self):

FILE: lib/fillers/graphfiller.py
  class GraphFiller (line 7) | class GraphFiller(Filler):
    method __init__ (line 9) | def __init__(self,
    method trim_seq (line 39) | def trim_seq(self, *seq):
    method training_step (line 45) | def training_step(self, batch, batch_idx):
    method validation_step (line 85) | def validation_step(self, batch, batch_idx):
    method test_step (line 116) | def test_step(self, batch, batch_idx):

FILE: lib/fillers/multi_imputation_filler.py
  class MultiImputationFiller (line 7) | class MultiImputationFiller(Filler):
    method __init__ (line 12) | def __init__(self,
    method forward (line 38) | def forward(self, *args, **kwargs):
    method _consistency_loss (line 45) | def _consistency_loss(self, imputations, mask):
    method training_step (line 49) | def training_step(self, batch, batch_idx):

FILE: lib/fillers/rgainfiller.py
  class MaskedBCEWithLogits (line 8) | class MaskedBCEWithLogits(MaskedMetric):
    method __init__ (line 9) | def __init__(self,
  class RGAINFiller (line 28) | class RGAINFiller(MultiImputationFiller):
    method __init__ (line 29) | def __init__(self,
    method training_step (line 65) | def training_step(self, batch, batch_idx):
    method configure_optimizers (line 137) | def configure_optimizers(self):

FILE: lib/nn/layers/gcrnn.py
  class GCGRUCell (line 7) | class GCGRUCell(nn.Module):
    method __init__ (line 12) | def __init__(self, d_in, num_units, support_len, order, activation='ta...
    method forward (line 28) | def forward(self, x, h, adj):
  class GCRNN (line 45) | class GCRNN(nn.Module):
    method __init__ (line 46) | def __init__(self,
    method init_hidden_states (line 66) | def init_hidden_states(self, x):
    method single_pass (line 69) | def single_pass(self, x, h, adj):
    method forward (line 75) | def forward(self, x, adj, h=None):

FILE: lib/nn/layers/gril.py
  class SpatialDecoder (line 11) | class SpatialDecoder(nn.Module):
    method __init__ (line 12) | def __init__(self, d_in, d_model, d_out, support_len, order=1, attenti...
    method forward (line 31) | def forward(self, x, m, h, u, adj, cached_support=False):
  class GRIL (line 57) | class GRIL(nn.Module):
    method __init__ (line 58) | def __init__(self,
    method init_hidden_states (line 106) | def init_hidden_states(self, n_nodes):
    method get_h0 (line 114) | def get_h0(self, x):
    method update_state (line 119) | def update_state(self, x, h, adj):
    method forward (line 127) | def forward(self, x, adj, mask=None, u=None, h=None, cached_support=Fa...
  class BiGRIL (line 181) | class BiGRIL(nn.Module):
    method __init__ (line 182) | def __init__(self,
    method forward (line 246) | def forward(self, x, adj, mask=None, u=None, cached_support=False):

FILE: lib/nn/layers/imputation.py
  class ImputationLayer (line 8) | class ImputationLayer(nn.Module):
    method __init__ (line 9) | def __init__(self, d_in, bias=True):
    method reset_parameters (line 20) | def reset_parameters(self):
    method forward (line 27) | def forward(self, x):

FILE: lib/nn/layers/mpgru.py
  class MPGRUImputer (line 7) | class MPGRUImputer(nn.Module):
    method __init__ (line 8) | def __init__(self,
    method init_hidden_states (line 58) | def init_hidden_states(self, n_nodes):
    method get_h0 (line 66) | def get_h0(self, x):
    method update_state (line 71) | def update_state(self, x, h, adj):
    method forward (line 79) | def forward(self, x, adj, mask=None, u=None, h=None):

FILE: lib/nn/layers/rits.py
  class FeatureRegression (line 12) | class FeatureRegression(nn.Module):
    method __init__ (line 13) | def __init__(self, input_size):
    method reset_parameters (line 23) | def reset_parameters(self):
    method forward (line 29) | def forward(self, x):
  class TemporalDecay (line 34) | class TemporalDecay(nn.Module):
    method __init__ (line 35) | def __init__(self, d_in, d_out, diag=False):
    method reset_parameters (line 48) | def reset_parameters(self):
    method compute_delta (line 55) | def compute_delta(mask, freq=1):
    method forward (line 63) | def forward(self, d):
  class RITS (line 72) | class RITS(nn.Module):
    method __init__ (line 73) | def __init__(self,
    method init_hidden_states (line 90) | def init_hidden_states(self, x):
    method forward (line 93) | def forward(self, x, mask=None, delta=None):
  class BRITS (line 145) | class BRITS(nn.Module):
    method __init__ (line 147) | def __init__(self, input_size, hidden_size):
    method forward (line 152) | def forward(self, x, mask=None):
    method consistency_loss (line 169) | def consistency_loss(imp_fwd, imp_bwd):

FILE: lib/nn/layers/spatial_attention.py
  class SpatialAttention (line 6) | class SpatialAttention(nn.Module):
    method __init__ (line 7) | def __init__(self, d_in, d_model, nheads, dropout=0.):
    method forward (line 12) | def forward(self, x, att_mask=None, **kwargs):

FILE: lib/nn/layers/spatial_conv.py
  class SpatialConvOrderK (line 7) | class SpatialConvOrderK(nn.Module):
    method __init__ (line 14) | def __init__(self, c_in, c_out, support_len=3, order=2, include_self=T...
    method compute_support (line 22) | def compute_support(adj, device=None):
    method compute_support_orderK (line 32) | def compute_support_orderK(adj, k, include_self=False, device=None):
    method forward (line 47) | def forward(self, x, support):

FILE: lib/nn/models/brits.py
  class BRITSNet (line 7) | class BRITSNet(nn.Module):
    method __init__ (line 8) | def __init__(self,
    method forward (line 15) | def forward(self, x, mask=None, **kwargs):
    method add_model_specific_args (line 27) | def add_model_specific_args(parser):

FILE: lib/nn/models/grin.py
  class GRINet (line 9) | class GRINet(nn.Module):
    method __init__ (line 10) | def __init__(self,
    method forward (line 47) | def forward(self, x, mask=None, u=None, **kwargs):
    method add_model_specific_args (line 69) | def add_model_specific_args(parser):

FILE: lib/nn/models/mpgru.py
  class MPGRUNet (line 10) | class MPGRUNet(nn.Module):
    method __init__ (line 11) | def __init__(self,
    method forward (line 38) | def forward(self, x, mask=None, u=None, h=None):
    method add_model_specific_args (line 59) | def add_model_specific_args(parser):
  class BiMPGRUNet (line 70) | class BiMPGRUNet(nn.Module):
    method __init__ (line 71) | def __init__(self,
    method forward (line 133) | def forward(self, x, mask=None, u=None, h=None):
    method add_model_specific_args (line 171) | def add_model_specific_args(parser):

FILE: lib/nn/models/rgain.py
  class Generator (line 8) | class Generator(nn.Module):
    method __init__ (line 9) | def __init__(self, d_in, d_model, d_z, dropout=0., inject_noise=True):
    method forward (line 21) | def forward(self, x, mask):
  class Discriminator (line 29) | class Discriminator(torch.nn.Module):
    method __init__ (line 30) | def __init__(self, d_in, d_model, dropout=0.):
    method forward (line 36) | def forward(self, x, h):
  class RGAINNet (line 43) | class RGAINNet(torch.nn.Module):
    method __init__ (line 44) | def __init__(self, d_in, d_model, d_z, dropout=0., inject_noise=False,...
    method forward (line 51) | def forward(self, x, mask, **kwargs):
    method add_model_specific_args (line 61) | def add_model_specific_args(parser):

FILE: lib/nn/models/rnn_imputers.py
  class RNNImputer (line 7) | class RNNImputer(nn.Module):
    method __init__ (line 10) | def __init__(self, d_in, d_model, concat_mask=True, detach_inputs=Fals...
    method init_hidden_state (line 20) | def init_hidden_state(self, x):
    method _preprocess_input (line 26) | def _preprocess_input(self, x, x_hat, m, u):
    method forward (line 38) | def forward(self, x, mask, u=None, return_hidden=False):
    method add_model_specific_args (line 63) | def add_model_specific_args(parser):
  class BiRNNImputer (line 69) | class BiRNNImputer(nn.Module):
    method __init__ (line 72) | def __init__(self, d_in, d_model, dropout=0., concat_mask=True, detach...
    method forward (line 82) | def forward(self, x, mask, u=None, return_hidden=False):
    method add_model_specific_args (line 98) | def add_model_specific_args(parser):
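
Judging from its signature (`concat_mask`, `forward(x, mask, ...)`), `RNNImputer` seems to follow the usual recurrent imputation loop. At each step it predicts the current values from the hidden state, fills the missing entries with that prediction, optionally appends the mask to the input, and updates the recurrent state. The NumPy sketch below shows that loop. It substitutes a plain tanh RNN cell with random weights for the module's learned cell. Every name except `concat_mask` is hypothetical.

```python
import numpy as np


def rnn_impute(x, mask, d_model=16, concat_mask=True, seed=0):
    """One-step-ahead recurrent imputation with a vanilla tanh RNN (random weights).

    x, mask: arrays of shape (steps, d_in); mask is 1 where x is observed.
    Returns (x_hat, x_filled): per-step predictions and the sequence with gaps filled.
    """
    rng = np.random.default_rng(seed)
    steps, d_in = x.shape
    d_rnn_in = 2 * d_in if concat_mask else d_in
    w_in = rng.normal(scale=0.1, size=(d_rnn_in, d_model))
    w_h = rng.normal(scale=0.1, size=(d_model, d_model))
    w_out = rng.normal(scale=0.1, size=(d_model, d_in))

    h = np.zeros(d_model)
    x_hat = np.zeros_like(x, dtype=float)
    x_filled = np.zeros_like(x, dtype=float)
    for t in range(steps):
        x_hat[t] = h @ w_out                             # predict step t from the past only
        x_filled[t] = np.where(mask[t], x[t], x_hat[t])  # keep observed values as they are
        inp = np.concatenate([x_filled[t], mask[t]]) if concat_mask else x_filled[t]
        h = np.tanh(inp @ w_in + h @ w_h)                # update the hidden state
    return x_hat, x_filled


x = np.arange(12, dtype=float).reshape(6, 2)
mask = np.ones_like(x)
mask[2:4, 1] = 0  # drop two consecutive values of the second channel
x_hat, x_filled = rnn_impute(x, mask)
```

A trained imputer would learn the weights by minimizing the reconstruction error on observed entries. A bidirectional variant such as `BiRNNImputer` presumably runs the same loop on the time-reversed sequence and combines both directions.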

FILE: lib/nn/models/var.py
  class VAR (line 8) | class VAR(nn.Module):
    method __init__ (line 9) | def __init__(self, order, d_in, d_out=None, steps_ahead=1, bias=True):
    method forward (line 17) | def forward(self, x):
    method add_model_specific_args (line 25) | def add_model_specific_args(parser):
  class VARImputer (line 33) | class VARImputer(nn.Module):
    method __init__ (line 36) | def __init__(self, order, d_in, padding='mean'):
    method forward (line 43) | def forward(self, x, mask=None):
    method add_model_specific_args (line 66) | def add_model_specific_args(parser):
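
`VAR(order, d_in, ...)` is a vector autoregressive model, x_t = b + A_1 x_{t-1} + ... + A_p x_{t-p}. The repository implements it as an `nn.Module`, so it is presumably trained by gradient descent. The sketch below instead fits the same model in closed form by ordinary least squares. That estimation method is swapped in here only to illustrate the model. `lagged_design`, `fit_var`, and `predict_next` are hypothetical names.

```python
import numpy as np


def lagged_design(x, order):
    """Rows [x_{t-1}, ..., x_{t-order}, 1] and targets x_t for t = order, ..., T-1."""
    steps = x.shape[0]
    lags = [x[order - k: steps - k] for k in range(1, order + 1)]
    ones = np.ones((steps - order, 1))  # column for the intercept b
    return np.hstack(lags + [ones]), x[order:]


def fit_var(x, order):
    """Least-squares VAR(order) fit on x of shape (T, d); returns A (order, d, d) and b (d,)."""
    design, target = lagged_design(x, order)
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)  # (order * d + 1, d)
    d = x.shape[1]
    a = coef[:-1].reshape(order, d, d).transpose(0, 2, 1)   # a[k] multiplies x_{t-k-1}
    return a, coef[-1]


def predict_next(history, a, b):
    """One-step-ahead forecast from the most recent rows of `history` (T, d)."""
    return b + sum(a[k] @ history[-1 - k] for k in range(a.shape[0]))


# Simulate a known VAR(1) with small Gaussian noise, then recover its parameters.
rng = np.random.default_rng(42)
a_true = np.array([[0.5, 0.1],
                   [0.0, 0.3]])
b_true = np.array([1.0, -0.5])
x = np.zeros((20000, 2))
for t in range(1, len(x)):
    x[t] = b_true + a_true @ x[t - 1] + rng.normal(scale=0.1, size=2)

a_hat, b_hat = fit_var(x, order=1)
print(a_hat[0].round(2), b_hat.round(2))
forecast = predict_next(x, a_hat, b_hat)
```

For imputation, `VARImputer(order, d_in, padding='mean')` presumably pads missing lagged inputs, for example with a mean value, before forecasting each step. That reading of `padding` is an assumption.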

FILE: lib/nn/utils/metric_base.py
  class MaskedMetric (line 8) | class MaskedMetric(Metric):
    method __init__ (line 9) | def __init__(self,
    method _check_mask (line 36) | def _check_mask(self, mask, val):
    method _compute_masked (line 47) | def _compute_masked(self, y_hat, y, mask):
    method _compute_std (line 54) | def _compute_std(self, y_hat, y):
    method is_masked (line 59) | def is_masked(self, mask):
    method update (line 62) | def update(self, y_hat, y, mask=None):
    method compute (line 74) | def compute(self):
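
The `update(y_hat, y, mask)` / `compute()` pair on `MaskedMetric` suggests the standard streaming pattern for masked errors. Across batches, the metric accumulates the sum of element-wise errors over valid entries together with the count of those entries, and it divides only once at the end. The result is the exact mean over all evaluated entries. Averaging per-batch means would bias the result whenever batches contain different numbers of valid entries. The pure-NumPy class below sketches this pattern and is not the repository's `Metric`-based class. `StreamingMaskedMAE` is a hypothetical name, and the skipping of NaN errors is an assumption of this sketch.

```python
import numpy as np


class StreamingMaskedMAE:
    """Masked mean absolute error accumulated over several batches."""

    def __init__(self):
        self.total = 0.0  # running sum of |y_hat - y| over valid entries
        self.count = 0    # running number of valid entries

    def update(self, y_hat, y, mask=None):
        err = np.abs(np.asarray(y_hat, dtype=float) - np.asarray(y, dtype=float))
        if mask is None:
            mask = np.ones_like(err, dtype=bool)
        mask = np.asarray(mask, dtype=bool) & ~np.isnan(err)  # also skip NaN errors
        self.total += float(err[mask].sum())
        self.count += int(mask.sum())

    def compute(self):
        return self.total / self.count if self.count else float("nan")


metric = StreamingMaskedMAE()
metric.update([1.0, 2.0, 3.0], [1.0, 0.0, 3.0], mask=[1, 1, 0])  # errors 0 and 2; third ignored
metric.update([5.0], [1.0])                                       # error 4
print(metric.compute())  # 2.0
```

Averaging the two batch means here would instead give (1 + 4) / 2 = 2.5.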

FILE: lib/nn/utils/metrics.py
  class MaskedMAE (line 11) | class MaskedMAE(MaskedMetric):
    method __init__ (line 12) | def __init__(self,
  class MaskedMAPE (line 31) | class MaskedMAPE(MaskedMetric):
    method __init__ (line 32) | def __init__(self,
  class MaskedMSE (line 49) | class MaskedMSE(MaskedMetric):
    method __init__ (line 50) | def __init__(self,
  class MaskedMRE (line 68) | class MaskedMRE(MaskedMetric):
    method __init__ (line 69) | def __init__(self,
    method _compute_masked (line 88) | def _compute_masked(self, y_hat, y, mask):
    method _compute_std (line 96) | def _compute_std(self, y_hat, y):
    method compute (line 101) | def compute(self):
    method update (line 106) | def update(self, y_hat, y, mask=None):

FILE: lib/nn/utils/ops.py
  function mae (line 9) | def mae(y_hat, y, reduction='none'):
  function mape (line 13) | def mape(y_hat, y):
  function wape_loss (line 17) | def wape_loss(y_hat, y):
  function smape_loss (line 22) | def smape_loss(y_hat, y):
  function peak_prediction_loss (line 30) | def peak_prediction_loss(y_hat, y, reduction='none'):
  function wrap_loss_fn (line 37) | def wrap_loss_fn(base_loss):
  function rbf_sim (line 55) | def rbf_sim(x, gamma, device='cpu'):
  function reverse_tensor (line 64) | def reverse_tensor(tensor=None, axis=-1):

FILE: lib/utils/numpy_metrics.py
  function mae (line 4) | def mae(y_hat, y):
  function nmae (line 8) | def nmae(y_hat, y):
  function mape (line 13) | def mape(y_hat, y):
  function mse (line 17) | def mse(y_hat, y):
  function rmse (line 21) | def rmse(y_hat, y):
  function nrmse (line 25) | def nrmse(y_hat, y):
  function nrmse_2 (line 30) | def nrmse_2(y_hat, y):
  function r2 (line 35) | def r2(y_hat, y):
  function masked_mae (line 39) | def masked_mae(y_hat, y, mask):
  function masked_mape (line 44) | def masked_mape(y_hat, y, mask):
  function masked_mse (line 49) | def masked_mse(y_hat, y, mask):
  function masked_rmse (line 54) | def masked_rmse(y_hat, y, mask):
  function masked_mre (line 59) | def masked_mre(y_hat, y, mask):
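
The masked NumPy metrics evaluate only the entries selected by `mask`. The sketch below assumes that `mask` is 1 (True) for the entries to evaluate. MRE (mean relative error) is commonly defined in imputation benchmarks such as BRITS as the sum of absolute errors divided by the sum of absolute target values. The function bodies are illustrative and are not copied from the repository.

```python
import numpy as np


def masked_mae(y_hat, y, mask):
    mask = np.asarray(mask, dtype=bool)
    return np.abs(y_hat[mask] - y[mask]).mean()


def masked_mse(y_hat, y, mask):
    mask = np.asarray(mask, dtype=bool)
    return np.square(y_hat[mask] - y[mask]).mean()


def masked_rmse(y_hat, y, mask):
    return np.sqrt(masked_mse(y_hat, y, mask))


def masked_mre(y_hat, y, mask):
    """Mean relative error: sum of absolute errors over sum of absolute targets."""
    mask = np.asarray(mask, dtype=bool)
    return np.abs(y_hat[mask] - y[mask]).sum() / np.abs(y[mask]).sum()


y = np.array([[1.0, 2.0], [3.0, 4.0]])
y_hat = np.array([[1.5, 2.0], [3.0, 100.0]])
eval_mask = np.array([[1, 1], [1, 0]])  # the last entry is excluded from evaluation
print(masked_mae(y_hat, y, eval_mask))  # (0.5 + 0 + 0) / 3
print(masked_mre(y_hat, y, eval_mask))  # 0.5 / (1 + 2 + 3)
```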

FILE: lib/utils/parser_utils.py
  function str_to_bool (line 6) | def str_to_bool(value):
  function config_dict_from_args (line 16) | def config_dict_from_args(args):
  function update_from_config (line 29) | def update_from_config(args: Namespace, config: dict):
  function parse_by_group (line 35) | def parse_by_group(parser):
  function filter_args (line 68) | def filter_args(args: Union[Namespace, dict], target_cls, return_dict=Fa...
  function filter_function_args (line 79) | def filter_function_args(args: Union[Namespace, dict], function, return_...
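
`str_to_bool` is the kind of helper passed as `type=` to `argparse` so that a flag such as `--detrend False` parses correctly. A plain `type=bool` would turn any non-empty string, including `'False'`, into `True`. The sketch below is a minimal version. The set of accepted spellings and the `nargs="?"`/`const=True` pattern are assumptions, not the repository's code.

```python
from argparse import ArgumentParser, ArgumentTypeError


def str_to_bool(value):
    """Parse common textual booleans; the accepted spellings are an assumption."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    if text in {"false", "f", "no", "n", "0"}:
        return False
    raise ArgumentTypeError(f"expected a boolean, got {value!r}")


parser = ArgumentParser()
# `--detrend` alone means True; `--detrend False` is parsed by str_to_bool.
parser.add_argument("--detrend", type=str_to_bool, nargs="?", const=True, default=False)
args = parser.parse_args(["--detrend", "False"])
print(args.detrend)  # False
```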

FILE: lib/utils/utils.py
  function sample_mask (line 7) | def sample_mask(shape, p=0.002, p_noise=0., max_seq=1, min_seq=1, rng=No...
  function compute_mean (line 30) | def compute_mean(x, index=None):
  function geographical_distance (line 56) | def geographical_distance(x=None, to_rad=True):
  function infer_mask (line 94) | def infer_mask(df, infer_from='next'):
  function prediction_dataframe (line 125) | def prediction_dataframe(y, index, columns=None, aggregate_by='mean'):
  function ensure_list (line 161) | def ensure_list(obj):
  function missing_val_lens (line 168) | def missing_val_lens(mask):
  function disjoint_months (line 181) | def disjoint_months(dataset, months=None, synch_mode='window'):
  function thresholded_gaussian_kernel (line 205) | def thresholded_gaussian_kernel(x, theta=None, threshold=None, threshold...
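
`sample_mask(shape, p, p_noise, max_seq, min_seq, rng)` appears to generate synthetic missing-data patterns. With probability `p`, a block of consecutive missing steps begins, and its length falls between `min_seq` and `max_seq`. Independent point-wise drops are added with probability `p_noise`. The reimplementation below is one way to do this, assuming `shape = (steps, nodes)` and a NumPy `Generator`. The returned array is 1 where a value is removed. `sample_block_mask` is a hypothetical name, and the probabilities in the example are arbitrary. The block-length range 4 to 9 matches the `min_seq`/`max_seq` entries in the synthetic configs.

```python
import numpy as np


def sample_block_mask(shape, p=0.002, p_noise=0.0, max_seq=1, min_seq=1, rng=None):
    """Return a uint8 array of `shape` (steps, nodes) with 1 marking removed values."""
    rng = np.random.default_rng() if rng is None else rng
    mask = np.zeros(shape, dtype=bool)
    starts = rng.random(shape) < p  # where a block failure begins
    for t, n in zip(*np.nonzero(starts)):
        length = int(rng.integers(min_seq, max_seq + 1))  # inclusive of max_seq
        mask[t:t + length, n] = True                       # slicing stops at the last step
    mask |= rng.random(shape) < p_noise                    # independent point-wise drops
    return mask.astype(np.uint8)


rng = np.random.default_rng(0)
m = sample_block_mask((1000, 5), p=0.01, p_noise=0.05, min_seq=4, max_seq=9, rng=rng)
print(m.shape, m.dtype, round(float(m.mean()), 3))
```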

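`geographical_distance` and `thresholded_gaussian_kernel` point to the usual recipe for building a sensor graph from coordinates. First, compute pairwise great-circle distances; the module imports scikit-learn's `haversine_distances` for this. Next, turn the distances into similarities with a Gaussian kernel w_ij = exp(-d_ij² / θ²). Finally, zero out weights below a threshold, presumably the `adj_threshold: 0.1` found in the configs. The sketch below implements the haversine formula directly in NumPy and uses the standard deviation of the distances as the default θ. The function names and that default are assumptions.

```python
import numpy as np

EARTH_RADIUS_KM = 6371.0


def haversine_km(coords_deg):
    """Pairwise great-circle distances (km) for an (n, 2) array of [lat, lon] in degrees."""
    lat, lon = np.radians(coords_deg).T
    dlat = lat[:, None] - lat[None, :]
    dlon = lon[:, None] - lon[None, :]
    a = np.sin(dlat / 2) ** 2 + np.cos(lat[:, None]) * np.cos(lat[None, :]) * np.sin(dlon / 2) ** 2
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))


def gaussian_kernel_adjacency(dist, theta=None, threshold=0.1):
    """w_ij = exp(-(d_ij / theta)^2), with weights below `threshold` set to zero."""
    dist = np.asarray(dist, dtype=float)
    theta = dist.std() if theta is None else theta
    adj = np.exp(-np.square(dist / theta))
    adj[adj < threshold] = 0.0
    return adj


stations = np.array([[39.90, 116.40],   # illustrative coordinates, not dataset values
                     [39.95, 116.45],
                     [31.23, 121.47]])
dist = haversine_km(stations)
adj = gaussian_kernel_adjacency(dist)
print(adj.round(3))
```

Nearby stations keep weights close to 1, while the distant third station is disconnected once its weight drops below the threshold.
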
FILE: scripts/run_baselines.py
  function parse_args (line 19) | def parse_args():
  class Imputer (line 45) | class Imputer:
    method __init__ (line 48) | def __init__(self, method=None, is_deterministic=True, in_sample=True):
    method fit (line 54) | def fit(self, x, mask):
    method predict (line 59) | def predict(self, x, mask):
    method params (line 66) | def params(self):
  class SpatialKNNImputer (line 70) | class SpatialKNNImputer(Imputer):
    method __init__ (line 73) | def __init__(self, adj, k=20):
    method fit (line 84) | def fit(self, x, mask):
    method predict (line 87) | def predict(self, x, mask):
    method params (line 94) | def params(self):
  class MeanImputer (line 98) | class MeanImputer(Imputer):
    method fit (line 101) | def fit(self, x, mask):
    method predict (line 105) | def predict(self, x, mask):
  class MatrixFactorizationImputer (line 114) | class MatrixFactorizationImputer(Imputer):
    method __init__ (line 117) | def __init__(self, rank=10, loss='mae', verbose=0):
    method params (line 121) | def params(self):
  class MICEImputer (line 125) | class MICEImputer(Imputer):
    method __init__ (line 128) | def __init__(self, max_iter=100, n_nearest_features=None, in_sample=Tr...
    method params (line 133) | def params(self):
  function get_dataset (line 137) | def get_dataset(dataset_name):
  function get_imputer (line 165) | def get_imputer(imputer_name, args):
  function run (line 181) | def run(imputer, dataset, train_slice):
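
`MeanImputer` and `SpatialKNNImputer(adj, k=20)` are simple reference baselines. The sketch below shows what they plausibly compute, assuming `x` has shape `(steps, nodes)` and `mask` is 1 for observed values. The mean imputer fills each node with the mean of its observed values. The spatial kNN imputer fills a missing value with the average of the observed values of the `k` most similar nodes (largest `adj` weights) at the same time step. The class names and the fallback to the node mean when no neighbour is observed are assumptions of this sketch.

```python
import numpy as np


class MeanFill:
    """Fill missing entries with the per-node mean of observed values."""

    def fit(self, x, mask):
        m = mask.astype(bool)
        self.means = np.where(m, x, 0.0).sum(axis=0) / np.maximum(m.sum(axis=0), 1)
        return self

    def predict(self, x, mask):
        return np.where(mask.astype(bool), x, self.means)


class SpatialKNNFill:
    """Fill missing entries with the mean of observed values at the k most similar nodes."""

    def __init__(self, adj, k=20):
        adj = np.array(adj, dtype=float)
        np.fill_diagonal(adj, -np.inf)              # a node is not its own neighbour
        k = min(k, adj.shape[0] - 1)
        self.knn = np.argsort(-adj, axis=1)[:, :k]  # (nodes, k) neighbour indices

    def predict(self, x, mask):
        m = mask.astype(bool)
        x0 = np.where(m, x, 0.0)                    # missing values contribute nothing
        num = x0[:, self.knn].sum(axis=-1)          # (steps, nodes)
        den = m[:, self.knn].sum(axis=-1)           # observed neighbours per entry
        fallback = MeanFill().fit(x, mask).predict(x, mask)
        y_hat = np.where(den > 0, num / np.maximum(den, 1), fallback)
        return np.where(m, x, y_hat)


adj = np.array([[1.0, 0.9, 0.1],
                [0.9, 1.0, 0.2],
                [0.1, 0.2, 1.0]])
x = np.array([[1.0, 2.0, 10.0],
              [3.0, np.nan, 30.0]])
mask = ~np.isnan(x)
filled = SpatialKNNFill(adj, k=1).predict(x, mask)
print(filled[1, 1])  # 3.0: node 0 is node 1's nearest neighbour
```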

FILE: scripts/run_imputation.py
  function has_graph_support (line 26) | def has_graph_support(model_cls):
  function get_model_classes (line 30) | def get_model_classes(model_str):
  function get_dataset (line 52) | def get_dataset(dataset_name):
  function parse_args (line 68) | def parse_args():
  function run_experiment (line 118) | def run_experiment(args):

FILE: scripts/run_synthetic.py
  function has_graph_support (line 25) | def has_graph_support(model_cls):
  function get_model_classes (line 29) | def get_model_classes(model_str):
  function parse_args (line 39) | def parse_args():
  function run_experiment (line 86) | def run_experiment(args):
Condensed preview: 102 files, each showing its path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 11,
    "preview": "*.DS_STORE\n"
  },
  {
    "path": "README.md",
    "chars": 4627,
    "preview": "# Filling the G_ap_s: Multivariate Time Series Imputation by Graph Neural Networks (ICLR 2022 - [open review](https://op"
  },
  {
    "path": "conda_env.yml",
    "chars": 364,
    "preview": "name: grin\nchannels:\n  - defaults\n  - pytorch\n  - conda-forge\ndependencies:\n  - pip\n  - pytables\n  - python=3.8\n  - pyto"
  },
  {
    "path": "config/bimpgru/air.yaml",
    "chars": 364,
    "preview": "dataset_name: 'air'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'g"
  },
  {
    "path": "config/bimpgru/air36.yaml",
    "chars": 366,
    "preview": "dataset_name: 'air36'\nwindow: 36\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', "
  },
  {
    "path": "config/bimpgru/bay_block.yaml",
    "chars": 370,
    "preview": "dataset_name: 'bay_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channel"
  },
  {
    "path": "config/bimpgru/bay_point.yaml",
    "chars": 370,
    "preview": "dataset_name: 'bay_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channel"
  },
  {
    "path": "config/bimpgru/irish_block.yaml",
    "chars": 372,
    "preview": "dataset_name: 'irish_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['chann"
  },
  {
    "path": "config/bimpgru/irish_point.yaml",
    "chars": 372,
    "preview": "dataset_name: 'irish_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['chann"
  },
  {
    "path": "config/bimpgru/la_block.yaml",
    "chars": 369,
    "preview": "dataset_name: 'la_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels"
  },
  {
    "path": "config/bimpgru/la_point.yaml",
    "chars": 369,
    "preview": "dataset_name: 'la_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels"
  },
  {
    "path": "config/brits/air.yaml",
    "chars": 232,
    "preview": "dataset_name: 'air'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsam"
  },
  {
    "path": "config/brits/air36.yaml",
    "chars": 233,
    "preview": "dataset_name: 'air36'\nwindow: 36\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\ns"
  },
  {
    "path": "config/brits/bay_block.yaml",
    "chars": 238,
    "preview": "dataset_name: 'bay_block'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 3"
  },
  {
    "path": "config/brits/bay_point.yaml",
    "chars": 238,
    "preview": "dataset_name: 'bay_point'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 3"
  },
  {
    "path": "config/brits/irish_block.yaml",
    "chars": 240,
    "preview": "dataset_name: 'irish_block'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs:"
  },
  {
    "path": "config/brits/irish_point.yaml",
    "chars": 240,
    "preview": "dataset_name: 'irish_point'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs:"
  },
  {
    "path": "config/brits/la_block.yaml",
    "chars": 237,
    "preview": "dataset_name: 'la_block'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 30"
  },
  {
    "path": "config/brits/la_point.yaml",
    "chars": 237,
    "preview": "dataset_name: 'la_point'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 30"
  },
  {
    "path": "config/brits/synthetic.yaml",
    "chars": 145,
    "preview": "window: 36\np_block: 0.025\np_point: 0.025\nmin_seq: 4\nmax_seq: 9\nuse_exogenous: False\n\nepochs: 200\nbatch_size: 32\n\nmodel_n"
  },
  {
    "path": "config/grin/air.yaml",
    "chars": 403,
    "preview": "dataset_name: 'air'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'g"
  },
  {
    "path": "config/grin/air36.yaml",
    "chars": 405,
    "preview": "dataset_name: 'air36'\nwindow: 36\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', "
  },
  {
    "path": "config/grin/bay_block.yaml",
    "chars": 409,
    "preview": "dataset_name: 'bay_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channel"
  },
  {
    "path": "config/grin/bay_point.yaml",
    "chars": 409,
    "preview": "dataset_name: 'bay_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channel"
  },
  {
    "path": "config/grin/irish_block.yaml",
    "chars": 411,
    "preview": "dataset_name: 'irish_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['chann"
  },
  {
    "path": "config/grin/irish_point.yaml",
    "chars": 411,
    "preview": "dataset_name: 'irish_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['chann"
  },
  {
    "path": "config/grin/la_block.yaml",
    "chars": 408,
    "preview": "dataset_name: 'la_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels"
  },
  {
    "path": "config/grin/la_point.yaml",
    "chars": 408,
    "preview": "dataset_name: 'la_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels"
  },
  {
    "path": "config/grin/synthetic.yaml",
    "chars": 252,
    "preview": "window: 36\np_block: 0.025\np_point: 0.025\nmin_seq: 4\nmax_seq: 9\nuse_exogenous: False\n\nepochs: 200\nbatch_size: 32\n\nmodel_n"
  },
  {
    "path": "config/mpgru/air.yaml",
    "chars": 360,
    "preview": "dataset_name: 'air'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', 'g"
  },
  {
    "path": "config/mpgru/air36.yaml",
    "chars": 362,
    "preview": "dataset_name: 'air36'\nwindow: 36\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels', "
  },
  {
    "path": "config/mpgru/bay_block.yaml",
    "chars": 366,
    "preview": "dataset_name: 'bay_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channel"
  },
  {
    "path": "config/mpgru/bay_point.yaml",
    "chars": 366,
    "preview": "dataset_name: 'bay_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channel"
  },
  {
    "path": "config/mpgru/irish_block.yaml",
    "chars": 368,
    "preview": "dataset_name: 'irish_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['chann"
  },
  {
    "path": "config/mpgru/irish_point.yaml",
    "chars": 368,
    "preview": "dataset_name: 'irish_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['chann"
  },
  {
    "path": "config/mpgru/la_block.yaml",
    "chars": 365,
    "preview": "dataset_name: 'la_block'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels"
  },
  {
    "path": "config/mpgru/la_point.yaml",
    "chars": 365,
    "preview": "dataset_name: 'la_point'\nwindow: 24\n\nadj_threshold: 0.1\n\ndetrend: False\nscale: True\nscaling_axis: 'global'  # ['channels"
  },
  {
    "path": "config/rgain/air.yaml",
    "chars": 410,
    "preview": "dataset_name: 'air'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True"
  },
  {
    "path": "config/rgain/air36.yaml",
    "chars": 411,
    "preview": "dataset_name: 'air36'\nwindow: 36\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: Tr"
  },
  {
    "path": "config/rgain/bay_block.yaml",
    "chars": 416,
    "preview": "dataset_name: 'bay_block'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target"
  },
  {
    "path": "config/rgain/bay_point.yaml",
    "chars": 416,
    "preview": "dataset_name: 'bay_point'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target"
  },
  {
    "path": "config/rgain/irish_block.yaml",
    "chars": 418,
    "preview": "dataset_name: 'irish_block'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_targ"
  },
  {
    "path": "config/rgain/irish_point.yaml",
    "chars": 418,
    "preview": "dataset_name: 'irish_point'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_targ"
  },
  {
    "path": "config/rgain/la_block.yaml",
    "chars": 415,
    "preview": "dataset_name: 'la_block'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target:"
  },
  {
    "path": "config/rgain/la_point.yaml",
    "chars": 415,
    "preview": "dataset_name: 'la_point'\nwindow: 24\nwhiten_prob: 0.2\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target:"
  },
  {
    "path": "config/var/air.yaml",
    "chars": 252,
    "preview": "dataset_name: 'air'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\nsam"
  },
  {
    "path": "config/var/air36.yaml",
    "chars": 254,
    "preview": "dataset_name: 'air36'\nwindow: 36\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 300\ns"
  },
  {
    "path": "config/var/bay_block.yaml",
    "chars": 258,
    "preview": "dataset_name: 'bay_block'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 3"
  },
  {
    "path": "config/var/bay_point.yaml",
    "chars": 258,
    "preview": "dataset_name: 'bay_point'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 3"
  },
  {
    "path": "config/var/irish_block.yaml",
    "chars": 260,
    "preview": "dataset_name: 'irish_block'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs:"
  },
  {
    "path": "config/var/irish_point.yaml",
    "chars": 260,
    "preview": "dataset_name: 'irish_point'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs:"
  },
  {
    "path": "config/var/la_block.yaml",
    "chars": 257,
    "preview": "dataset_name: 'la_block'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 30"
  },
  {
    "path": "config/var/la_point.yaml",
    "chars": 257,
    "preview": "dataset_name: 'la_point'\nwindow: 24\n\ndetrend: False\nscale: True\nscaling_axis: 'channels'\nscaled_target: True\n\nepochs: 30"
  },
  {
    "path": "lib/__init__.py",
    "chars": 442,
    "preview": "import os\n\nbase_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))\n\nconfig = {\n    'logs': 'logs/'\n}\ndat"
  },
  {
    "path": "lib/data/__init__.py",
    "chars": 104,
    "preview": "from .temporal_dataset import TemporalDataset\nfrom .spatiotemporal_dataset import SpatioTemporalDataset\n"
  },
  {
    "path": "lib/data/datamodule/__init__.py",
    "chars": 53,
    "preview": "from .spatiotemporal import SpatioTemporalDataModule\n"
  },
  {
    "path": "lib/data/datamodule/spatiotemporal.py",
    "chars": 5840,
    "preview": "import pytorch_lightning as pl\nfrom torch.utils.data import DataLoader, Subset, RandomSampler\n\nfrom .. import TemporalDa"
  },
  {
    "path": "lib/data/imputation_dataset.py",
    "chars": 1615,
    "preview": "import numpy as np\nimport torch\n\nfrom . import TemporalDataset, SpatioTemporalDataset\n\n\nclass ImputationDataset(Temporal"
  },
  {
    "path": "lib/data/preprocessing/__init__.py",
    "chars": 23,
    "preview": "from .scalers import *\n"
  },
  {
    "path": "lib/data/preprocessing/scalers.py",
    "chars": 2851,
    "preview": "from abc import ABC, abstractmethod\nimport numpy as np\n\n\nclass AbstractScaler(ABC):\n\n    def __init__(self, **kwargs):\n "
  },
  {
    "path": "lib/data/spatiotemporal_dataset.py",
    "chars": 3006,
    "preview": "import numpy as np\nimport pandas as pd\nfrom einops import rearrange\n\nfrom .temporal_dataset import TemporalDataset\n\n\ncla"
  },
  {
    "path": "lib/data/temporal_dataset.py",
    "chars": 12077,
    "preview": "import numpy as np\nimport pandas as pd\nimport torch\nfrom einops import rearrange\nfrom pandas import DatetimeIndex\nfrom t"
  },
  {
    "path": "lib/datasets/__init__.py",
    "chars": 160,
    "preview": "from .air_quality import AirQuality\nfrom .metr_la import MissingValuesMetrLA\nfrom .pems_bay import MissingValuesPemsBay\n"
  },
  {
    "path": "lib/datasets/air_quality.py",
    "chars": 4732,
    "preview": "import os\n\nimport numpy as np\nimport pandas as pd\n\nfrom lib import datasets_path\nfrom .pd_dataset import PandasDataset\nf"
  },
  {
    "path": "lib/datasets/metr_la.py",
    "chars": 3757,
    "preview": "import os\n\nimport numpy as np\nimport pandas as pd\n\nfrom lib import datasets_path\nfrom .pd_dataset import PandasDataset\nf"
  },
  {
    "path": "lib/datasets/pd_dataset.py",
    "chars": 3435,
    "preview": "import numpy as np\nimport pandas as pd\nimport torch\n\n\nclass PandasDataset:\n    def __init__(self, dataframe: pd.DataFram"
  },
  {
    "path": "lib/datasets/pems_bay.py",
    "chars": 4199,
    "preview": "import os\n\nimport numpy as np\nimport pandas as pd\n\nfrom lib import datasets_path\nfrom .pd_dataset import PandasDataset\nf"
  },
  {
    "path": "lib/datasets/synthetic.py",
    "chars": 7408,
    "preview": "import os.path\n\nimport numpy as np\nimport torch\nfrom einops import rearrange\nfrom torch.utils.data import Dataset, DataL"
  },
  {
    "path": "lib/fillers/__init__.py",
    "chars": 197,
    "preview": "from .filler import Filler\nfrom .britsfiller import BRITSFiller\nfrom .graphfiller import GraphFiller\nfrom .rgainfiller i"
  },
  {
    "path": "lib/fillers/britsfiller.py",
    "chars": 3431,
    "preview": "import torch\n\nfrom . import Filler\nfrom ..nn import BRITS\n\n\nclass BRITSFiller(Filler):\n\n    def training_step(self, batc"
  },
  {
    "path": "lib/fillers/filler.py",
    "chars": 12351,
    "preview": "import inspect\nfrom copy import deepcopy\n\nimport pytorch_lightning as pl\nimport torch\nfrom pytorch_lightning.core.decora"
  },
  {
    "path": "lib/fillers/graphfiller.py",
    "chars": 5454,
    "preview": "import torch\n\nfrom . import Filler\nfrom ..nn.models import MPGRUNet, GRINet, BiMPGRUNet\n\n\nclass GraphFiller(Filler):\n\n  "
  },
  {
    "path": "lib/fillers/multi_imputation_filler.py",
    "chars": 2930,
    "preview": "import torch\nfrom pytorch_lightning.core.decorators import auto_move_data\n\nfrom . import Filler\n\n\nclass MultiImputationF"
  },
  {
    "path": "lib/fillers/rgainfiller.py",
    "chars": 6340,
    "preview": "import torch\nfrom torch.nn import functional as F\n\nfrom .multi_imputation_filler import MultiImputationFiller\nfrom ..nn."
  },
  {
    "path": "lib/nn/__init__.py",
    "chars": 22,
    "preview": "from .layers import *\n"
  },
  {
    "path": "lib/nn/layers/__init__.py",
    "chars": 137,
    "preview": "from .rits import RITS, BRITS\nfrom .gril import GRIL, BiGRIL\nfrom .spatial_conv import SpatialConvOrderK\nfrom .mpgru imp"
  },
  {
    "path": "lib/nn/layers/gcrnn.py",
    "chars": 3184,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom .spatial_conv import SpatialConvOrderK\n\n\nclass GCGRUCell(nn.Module):\n    \"\"\"\n  "
  },
  {
    "path": "lib/nn/layers/gril.py",
    "chars": 11916,
    "preview": "import torch\nimport torch.nn as nn\nfrom einops import rearrange\n\nfrom .spatial_conv import SpatialConvOrderK\nfrom .gcrnn"
  },
  {
    "path": "lib/nn/layers/imputation.py",
    "chars": 897,
    "preview": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass ImputationLayer(nn.Module):\n"
  },
  {
    "path": "lib/nn/layers/mpgru.py",
    "chars": 4835,
    "preview": "import torch\nfrom torch import nn\n\nfrom .gcrnn import GCGRUCell\n\n\nclass MPGRUImputer(nn.Module):\n    def __init__(self,\n"
  },
  {
    "path": "lib/nn/layers/rits.py",
    "chars": 5700,
    "preview": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nfrom"
  },
  {
    "path": "lib/nn/layers/spatial_attention.py",
    "chars": 971,
    "preview": "import torch.nn as nn\n\nfrom einops import rearrange\n\n\nclass SpatialAttention(nn.Module):\n    def __init__(self, d_in, d_"
  },
  {
    "path": "lib/nn/layers/spatial_conv.py",
    "chars": 2253,
    "preview": "import torch\nfrom torch import nn\n\nfrom ... import epsilon\n\n\nclass SpatialConvOrderK(nn.Module):\n    \"\"\"\n    Spatial con"
  },
  {
    "path": "lib/nn/models/__init__.py",
    "chars": 200,
    "preview": "from .grin import GRINet\nfrom .brits import BRITSNet\nfrom .mpgru import MPGRUNet, BiMPGRUNet\nfrom .var import VARImputer"
  },
  {
    "path": "lib/nn/models/brits.py",
    "chars": 1171,
    "preview": "import torch\nfrom torch import nn\n\nfrom ..layers import BRITS\n\n\nclass BRITSNet(nn.Module):\n    def __init__(self,\n      "
  },
  {
    "path": "lib/nn/models/grin.py",
    "chars": 3589,
    "preview": "import torch\nfrom einops import rearrange\nfrom torch import nn\n\nfrom ..layers import BiGRIL\nfrom ...utils.parser_utils i"
  },
  {
    "path": "lib/nn/models/mpgru.py",
    "chars": 8025,
    "preview": "import torch\nfrom einops import rearrange\nfrom torch import nn\n\nfrom ..layers import MPGRUImputer, SpatialConvOrderK\nfro"
  },
  {
    "path": "lib/nn/models/rgain.py",
    "chars": 2564,
    "preview": "import torch\nfrom torch import nn\n\nfrom .rnn_imputers import BiRNNImputer\nfrom ...utils.parser_utils import str_to_bool\n"
  },
  {
    "path": "lib/nn/models/rnn_imputers.py",
    "chars": 3972,
    "preview": "import torch\nfrom torch import nn\n\nfrom ..utils.ops import reverse_tensor\n\n\nclass RNNImputer(nn.Module):\n    \"\"\"Fill the"
  },
  {
    "path": "lib/nn/models/var.py",
    "chars": 2612,
    "preview": "import torch\nfrom einops import rearrange\nfrom torch import nn\n\nfrom lib import epsilon\n\n\nclass VAR(nn.Module):\n    def "
  },
  {
    "path": "lib/nn/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "lib/nn/utils/metric_base.py",
    "chars": 2660,
    "preview": "from functools import partial\n\nimport torch\nfrom pytorch_lightning.metrics import Metric\nfrom torchmetrics.utilities.che"
  },
  {
    "path": "lib/nn/utils/metrics.py",
    "chars": 4835,
    "preview": "from .metric_base import MaskedMetric\nfrom .ops import mape\nfrom torch.nn import functional as F\nimport torch\n\nfrom torc"
  },
  {
    "path": "lib/nn/utils/ops.py",
    "chars": 1953,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom einops import reduce\nfrom torch.autograd import Variable\n\nfrom ... imp"
  },
  {
    "path": "lib/utils/__init__.py",
    "chars": 21,
    "preview": "from .utils import *\n"
  },
  {
    "path": "lib/utils/numpy_metrics.py",
    "chars": 1256,
    "preview": "import numpy as np\n\n\ndef mae(y_hat, y):\n    return np.abs(y_hat - y).mean()\n\n\ndef nmae(y_hat, y):\n    delta = np.max(y) "
  },
  {
    "path": "lib/utils/parser_utils.py",
    "chars": 3358,
    "preview": "import inspect\nfrom argparse import Namespace, ArgumentParser\nfrom typing import Union\n\n\ndef str_to_bool(value):\n    if "
  },
  {
    "path": "lib/utils/utils.py",
    "chars": 8755,
    "preview": "import numpy as np\nimport pandas as pd\n\nfrom sklearn.metrics.pairwise import haversine_distances\n\n\ndef sample_mask(shape"
  },
  {
    "path": "requirements.txt",
    "chars": 211,
    "preview": "einops\nfancyimpute==0.6\nh5py\nopenpyxl\nnumpy\npandas\npytorch-lightning==1.4\npyyaml\nscikit-learn\nscipy\ntables\ntensorboard\nt"
  },
  {
    "path": "scripts/run_baselines.py",
    "chars": 7437,
    "preview": "from argparse import ArgumentParser\n\nimport numpy as np\nfrom fancyimpute import MatrixFactorization, IterativeImputer\nfr"
  },
  {
    "path": "scripts/run_imputation.py",
    "chars": 12130,
    "preview": "import copy\nimport datetime\nimport os\nimport pathlib\nfrom argparse import ArgumentParser\n\nimport numpy as np\nimport pyto"
  },
  {
    "path": "scripts/run_synthetic.py",
    "chars": 9402,
    "preview": "import copy\nimport datetime\nimport os\nimport pathlib\nfrom argparse import ArgumentParser\n\nimport numpy as np\nimport pyto"
  }
]
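
Most of the configuration files above set `scale: True` together with a `scaling_axis`, either `'global'` or `'channels'`; the preview truncates the list of valid options. The minimal standard scaler below assumes that `'global'` means a single mean and standard deviation over all training values. It assumes that `'channels'` means one pair per node/channel, computed over time. The class name, this axis mapping, and the `(steps, nodes)` layout are assumptions, not the contents of `lib/data/preprocessing/scalers.py`.

```python
import numpy as np


class StandardScalerSketch:
    """z = (x - mean) / std, with statistics computed over the chosen axes."""

    def __init__(self, scaling_axis="global"):
        # Assumed mapping: 'global' -> one statistic for everything,
        # 'channels' -> one statistic per node/channel (reduce over time only).
        self.axis = None if scaling_axis == "global" else 0

    def fit(self, x, mask=None):
        x = np.asarray(x, dtype=float)
        if mask is not None:
            x = np.where(mask, x, np.nan)  # ignore missing values when fitting
        self.mean = np.nanmean(x, axis=self.axis, keepdims=True)
        self.std = np.nanstd(x, axis=self.axis, keepdims=True)
        self.std[self.std == 0] = 1.0      # avoid division by zero for constant series
        return self

    def transform(self, x):
        return (x - self.mean) / self.std

    def inverse_transform(self, z):
        return z * self.std + self.mean


x = np.array([[1.0, 10.0],
              [2.0, 20.0],
              [3.0, 30.0]])                            # 3 steps, 2 nodes
mask = np.array([[1, 1], [1, 1], [1, 0]], dtype=bool)  # last value of node 1 treated as missing
per_node = StandardScalerSketch("channels").fit(x, mask)
z = per_node.transform(x)
```

In practice, the statistics would be fit on the training split only and then reused for the validation and test windows.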

About this extraction

This page contains the full source code of the Graph-Machine-Learning-Group/grin GitHub repository, extracted and formatted as plain text. The extraction includes 102 files (201.4 KB), approximately 52.6k tokens, and a symbol index with 371 extracted functions, classes, methods, constants, and types.
