Repository: lukecavabarrett/pna Branch: master Commit: 6867d9aadcbf Files: 53 Total size: 229.4 KB Directory structure: gitextract_i5b7q8_j/ ├── LICENSE ├── README.md ├── models/ │ ├── dgl/ │ │ ├── aggregators.py │ │ ├── pna_layer.py │ │ └── scalers.py │ ├── layers.py │ ├── pytorch/ │ │ ├── gat/ │ │ │ └── layer.py │ │ ├── gcn/ │ │ │ └── layer.py │ │ ├── gin/ │ │ │ └── layer.py │ │ ├── gnn_framework.py │ │ └── pna/ │ │ ├── aggregators.py │ │ ├── layer.py │ │ └── scalers.py │ └── pytorch_geometric/ │ ├── aggregators.py │ ├── example.py │ ├── pna.py │ └── scalers.py ├── multitask_benchmark/ │ ├── README.md │ ├── datasets_generation/ │ │ ├── graph_algorithms.py │ │ ├── graph_generation.py │ │ └── multitask_dataset.py │ ├── requirements.txt │ ├── train/ │ │ ├── gat.py │ │ ├── gcn.py │ │ ├── gin.py │ │ ├── mpnn.py │ │ └── pna.py │ └── util/ │ ├── train.py │ └── util.py └── realworld_benchmark/ ├── README.md ├── configs/ │ ├── molecules_graph_classification_PNA_HIV.json │ ├── molecules_graph_regression_pna_ZINC.json │ ├── superpixels_graph_classification_pna_CIFAR10.json │ └── superpixels_graph_classification_pna_MNIST.json ├── data/ │ ├── HIV.py │ ├── download_datasets.sh │ ├── molecules.py │ └── superpixels.py ├── docs/ │ └── setup.md ├── environment_cpu.yml ├── environment_gpu.yml ├── main_HIV.py ├── main_molecules.py ├── main_superpixels.py ├── nets/ │ ├── HIV_graph_classification/ │ │ └── pna_net.py │ ├── gru.py │ ├── mlp_readout_layer.py │ ├── molecules_graph_regression/ │ │ └── pna_net.py │ └── superpixels_graph_classification/ │ └── pna_net.py └── train/ ├── metrics.py ├── train_HIV_graph_classification.py ├── train_molecules_graph_regression.py └── train_superpixels_graph_classification.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2020 Gabriele Corso, Luca 
Cavalleri Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Principal Neighbourhood Aggregation Implementation of Principal Neighbourhood Aggregation for Graph Nets [arxiv.org/abs/2004.05718](https://arxiv.org/abs/2004.05718) in PyTorch, DGL and PyTorch Geometric. *Update: now you can find PNA directly integrated in both [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.PNAConv) and [DGL](https://docs.dgl.ai/generated/dgl.nn.pytorch.conv.PNAConv.html)!* ![symbol](./multitask_benchmark/images/symbol.png) ## Overview We provide the implementation of the Principal Neighbourhood Aggregation (PNA) in PyTorch, DGL and PyTorch Geometric frameworks, along with scripts to generate and run the multitask benchmarks, scripts for running real-world benchmarks, a flexible PyTorch GNN framework and implementations of the other models used for comparison. 
The repository is organised as follows:

- `models` contains:
  - `pytorch`: the various GNN models implemented in PyTorch:
    - the implementation of the aggregators, the scalers and the PNA layer (`pna`)
    - the flexible GNN framework that can be used with any type of graph convolution (`gnn_framework.py`)
    - implementations of the other GNN models used for comparison in the paper, namely GCN, GAT, GIN and MPNN
  - `dgl`: the PNA model implemented via the [DGL library](https://www.dgl.ai/): aggregators, scalers and layer
  - `pytorch_geometric`: the PNA model implemented via the [PyTorch Geometric library](https://pytorch-geometric.readthedocs.io/): aggregators, scalers and layer
  - `layers.py`: general NN layers used by the various models
- `multitask_benchmark` contains the scripts to recreate the multitask benchmark along with the files used to train the various models on it. `multitask_benchmark/README.md` details the dataset-generation instructions and the training hyperparameters used.
- `realworld_benchmark` contains scripts from [Benchmarking GNNs](https://github.com/graphdeeplearning/benchmarking-gnns) to download the real-world benchmarks and train the PNA on them. `realworld_benchmark/README.md` details the download instructions and the training hyperparameters used.

![results](./multitask_benchmark/images/results.png)

## Reference

```
@inproceedings{corso2020pna,
 title = {Principal Neighbourhood Aggregation for Graph Nets},
 author = {Corso, Gabriele and Cavalleri, Luca and Beaini, Dominique and Li\`{o}, Pietro and Veli\v{c}kovi\'{c}, Petar},
 booktitle = {Advances in Neural Information Processing Systems},
 year = {2020}
}
```

## License

MIT

## Acknowledgements

The authors would like to thank Saro Passaro for running some of the tests presented in this repository and Giorgos Bouritsas, Fabrizio Frasca, Leonardo Cotta, Zhanghao Wu, Zhanqiu Zhang and George Watkins for pointing out some issues with the code.
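Before the file listings, a quick illustration of the core operation may help: for each node, PNA applies several aggregators to the multiset of neighbour messages and then rescales each result with degree-based scalers. The snippet below is a minimal NumPy sketch of that pipeline, mirroring the torch implementations in `models/dgl/aggregators.py` and `models/dgl/scalers.py`; the helper names `aggregate` and `scale` are illustrative, not identifiers from the repository.

```python
import numpy as np

EPS = 1e-5  # same numerical stabiliser used in models/dgl/aggregators.py

def aggregate(h):
    # h: (deg, D) messages from the neighbours of a single node;
    # concatenate the mean, max, min and (stabilised) std aggregations
    return np.concatenate([h.mean(axis=0), h.max(axis=0), h.min(axis=0),
                           np.sqrt(h.var(axis=0) + EPS)])

def scale(agg, deg, avg_log_deg):
    # identity, amplification log(deg + 1) / d and attenuation d / log(deg + 1),
    # where d is the average of log(deg + 1) over the training set
    amp = np.log(deg + 1) / avg_log_deg
    return np.concatenate([agg, agg * amp, agg / amp])

h = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # one node, 3 neighbours, 2 features
out = scale(aggregate(h), deg=3, avg_log_deg=np.log(4))
print(out.shape)  # 4 aggregators x 3 scalers x 2 features -> (24,)
```

In the actual layers the concatenated result is then fed through the post-aggregation MLP (`posttrans` in `models/dgl/pna_layer.py`).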
================================================ FILE: models/dgl/aggregators.py ================================================ import torch EPS = 1e-5 def aggregate_mean(h): return torch.mean(h, dim=1) def aggregate_max(h): return torch.max(h, dim=1)[0] def aggregate_min(h): return torch.min(h, dim=1)[0] def aggregate_std(h): return torch.sqrt(aggregate_var(h) + EPS) def aggregate_var(h): h_mean_squares = torch.mean(h * h, dim=-2) h_mean = torch.mean(h, dim=-2) var = torch.relu(h_mean_squares - h_mean * h_mean) return var def aggregate_moment(h, n=3): # for each node (E[(X-E[X])^n])^{1/n} # EPS is added to the absolute value of expectation before taking the nth root for stability h_mean = torch.mean(h, dim=1, keepdim=True) h_n = torch.mean(torch.pow(h - h_mean, n), dim=1) rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + EPS, 1. / n) return rooted_h_n def aggregate_moment_3(h): return aggregate_moment(h, n=3) def aggregate_moment_4(h): return aggregate_moment(h, n=4) def aggregate_moment_5(h): return aggregate_moment(h, n=5) def aggregate_sum(h): return torch.sum(h, dim=1) AGGREGATORS = {'mean': aggregate_mean, 'sum': aggregate_sum, 'max': aggregate_max, 'min': aggregate_min, 'std': aggregate_std, 'var': aggregate_var, 'moment3': aggregate_moment_3, 'moment4': aggregate_moment_4, 'moment5': aggregate_moment_5} ================================================ FILE: models/dgl/pna_layer.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import dgl.function as fn from .aggregators import AGGREGATORS from models.layers import MLP, FCLayer from .scalers import SCALERS """ PNA: Principal Neighbourhood Aggregation Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic https://arxiv.org/abs/2004.05718 """ class PNATower(nn.Module): def __init__(self, in_dim, out_dim, dropout, graph_norm, batch_norm, aggregators, scalers, avg_d, pretrans_layers, posttrans_layers, edge_features,
edge_dim): super().__init__() self.dropout = dropout self.graph_norm = graph_norm self.batch_norm = batch_norm self.edge_features = edge_features self.batchnorm_h = nn.BatchNorm1d(out_dim) self.aggregators = aggregators self.scalers = scalers self.pretrans = MLP(in_size=2 * in_dim + (edge_dim if edge_features else 0), hidden_size=in_dim, out_size=in_dim, layers=pretrans_layers, mid_activation='relu', last_activation='none') self.posttrans = MLP(in_size=(len(aggregators) * len(scalers) + 1) * in_dim, hidden_size=out_dim, out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none') self.avg_d = avg_d def pretrans_edges(self, edges): if self.edge_features: z2 = torch.cat([edges.src['h'], edges.dst['h'], edges.data['ef']], dim=1) else: z2 = torch.cat([edges.src['h'], edges.dst['h']], dim=1) return {'e': self.pretrans(z2)} def message_func(self, edges): return {'e': edges.data['e']} def reduce_func(self, nodes): h = nodes.mailbox['e'] D = h.shape[-2] h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1) h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1) return {'h': h} def posttrans_nodes(self, nodes): return self.posttrans(nodes.data['h']) def forward(self, g, h, e, snorm_n): g.ndata['h'] = h if self.edge_features: # add the edges information only if edge_features = True g.edata['ef'] = e # pretransformation g.apply_edges(self.pretrans_edges) # aggregation g.update_all(self.message_func, self.reduce_func) h = torch.cat([h, g.ndata['h']], dim=1) # posttransformation h = self.posttrans(h) # graph and batch normalization if self.graph_norm: h = h * snorm_n if self.batch_norm: h = self.batchnorm_h(h) h = F.dropout(h, self.dropout, training=self.training) return h class PNALayer(nn.Module): def __init__(self, in_dim, out_dim, aggregators, scalers, avg_d, dropout, graph_norm, batch_norm, towers=1, pretrans_layers=1, posttrans_layers=1, divide_input=True, residual=False, edge_features=False, 
edge_dim=0): """ :param in_dim: size of the input per node :param out_dim: size of the output per node :param aggregators: set of aggregation function identifiers :param scalers: set of scaling functions identifiers :param avg_d: average degree of nodes in the training set, used by scalers to normalize :param dropout: dropout used :param graph_norm: whether to use graph normalisation :param batch_norm: whether to use batch normalisation :param towers: number of towers to use :param pretrans_layers: number of layers in the transformation before the aggregation :param posttrans_layers: number of layers in the transformation after the aggregation :param divide_input: whether the input features should be split between towers or not :param residual: whether to add a residual connection :param edge_features: whether to use the edge features :param edge_dim: size of the edge features """ super().__init__() assert ((not divide_input) or in_dim % towers == 0), "if divide_input is set the number of towers has to divide in_dim" assert (out_dim % towers == 0), "the number of towers has to divide the out_dim" assert avg_d is not None # retrieve the aggregators and scalers functions aggregators = [AGGREGATORS[aggr] for aggr in aggregators.split()] scalers = [SCALERS[scale] for scale in scalers.split()] self.divide_input = divide_input self.input_tower = in_dim // towers if divide_input else in_dim self.output_tower = out_dim // towers self.in_dim = in_dim self.out_dim = out_dim self.edge_features = edge_features self.residual = residual if in_dim != out_dim: self.residual = False # convolution self.towers = nn.ModuleList() for _ in range(towers): self.towers.append(PNATower(in_dim=self.input_tower, out_dim=self.output_tower, aggregators=aggregators, scalers=scalers, avg_d=avg_d, pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers, batch_norm=batch_norm, dropout=dropout, graph_norm=graph_norm, edge_features=edge_features, edge_dim=edge_dim)) # mixing network 
self.mixing_network = FCLayer(out_dim, out_dim, activation='LeakyReLU') def forward(self, g, h, e, snorm_n): h_in = h # for residual connection if self.divide_input: h_cat = torch.cat( [tower(g, h[:, n_tower * self.input_tower: (n_tower + 1) * self.input_tower], e, snorm_n) for n_tower, tower in enumerate(self.towers)], dim=1) else: h_cat = torch.cat([tower(g, h, e, snorm_n) for tower in self.towers], dim=1) h_out = self.mixing_network(h_cat) if self.residual: h_out = h_in + h_out # residual connection return h_out def __repr__(self): return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, self.in_dim, self.out_dim) class PNASimpleLayer(nn.Module): def __init__(self, in_dim, out_dim, aggregators, scalers, avg_d, dropout, batch_norm, residual, posttrans_layers=1): """ A simpler version of PNA layer that simply aggregates the neighbourhood (similar to GCN and GIN), without using the pretransformation or the tower mechanisms of the MPNN. It does not support edge features. 
:param in_dim: size of the input per node :param out_dim: size of the output per node :param aggregators: set of aggregation function identifiers :param scalers: set of scaling functions identifiers :param avg_d: average degree of nodes in the training set, used by scalers to normalize :param dropout: dropout used :param batch_norm: whether to use batch normalisation :param posttrans_layers: number of layers in the transformation after the aggregation """ super().__init__() # retrieve the aggregators and scalers functions aggregators = [AGGREGATORS[aggr] for aggr in aggregators.split()] scalers = [SCALERS[scale] for scale in scalers.split()] self.aggregators = aggregators self.scalers = scalers self.in_dim = in_dim self.out_dim = out_dim self.dropout = dropout self.batch_norm = batch_norm self.residual = residual self.batchnorm_h = nn.BatchNorm1d(out_dim) self.posttrans = MLP(in_size=(len(aggregators) * len(scalers)) * in_dim, hidden_size=out_dim, out_size=out_dim, layers=posttrans_layers, mid_activation='relu', last_activation='none') self.avg_d = avg_d def reduce_func(self, nodes): h = nodes.mailbox['m'] D = h.shape[-2] h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1) h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1) return {'h': h} def forward(self, g, h): h_in = h g.ndata['h'] = h # aggregation g.update_all(fn.copy_u('h', 'm'), self.reduce_func) h = g.ndata['h'] # posttransformation h = self.posttrans(h) # batch normalization and residual if self.batch_norm: h = self.batchnorm_h(h) h = F.relu(h) if self.residual: h = h_in + h h = F.dropout(h, self.dropout, training=self.training) return h def __repr__(self): return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, self.in_dim, self.out_dim) ================================================ FILE: models/dgl/scalers.py ================================================ import torch import numpy as np # each scaler is a function that takes as 
input X (B x N x Din), adj (B x N x N) and # avg_d (dictionary containing averages over training set) and returns X_scaled (B x N x Din) as output def scale_identity(h, D=None, avg_d=None): return h def scale_amplification(h, D, avg_d): # log(D + 1) / d * h where d is the average of the ``log(D + 1)`` in the training set return h * (np.log(D + 1) / avg_d["log"]) def scale_attenuation(h, D, avg_d): # (log(D + 1))^-1 / d * X where d is the average of the ``log(D + 1))^-1`` in the training set return h * (avg_d["log"] / np.log(D + 1)) SCALERS = {'identity': scale_identity, 'amplification': scale_amplification, 'attenuation': scale_attenuation} ================================================ FILE: models/layers.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F SUPPORTED_ACTIVATION_MAP = {'ReLU', 'Sigmoid', 'Tanh', 'ELU', 'SELU', 'GLU', 'LeakyReLU', 'Softplus', 'None'} def get_activation(activation): """ returns the activation function represented by the input string """ if activation and callable(activation): # activation is already a function return activation # search in SUPPORTED_ACTIVATION_MAP a torch.nn.modules.activation activation = [x for x in SUPPORTED_ACTIVATION_MAP if activation.lower() == x.lower()] assert len(activation) == 1 and isinstance(activation[0], str), 'Unhandled activation function' activation = activation[0] if activation.lower() == 'none': return None return vars(torch.nn.modules.activation)[activation]() class Set2Set(torch.nn.Module): r""" Set2Set global pooling operator from the `"Order Matters: Sequence to sequence for sets" `_ paper. This pooling layer performs the following operation .. 
math:: \mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1}) \alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t) \mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i \mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t, where :math:`\mathbf{q}^{*}_T` defines the output of the layer with twice the dimensionality as the input. Arguments --------- input_dim: int Size of each input sample. hidden_dim: int, optional the dim of set representation which corresponds to the input dim of the LSTM in Set2Set. This is typically the sum of the input dim and the lstm output dim. If not provided, it will be set to :obj:`input_dim*2` steps: int, optional Number of iterations :math:`T`. If not provided, the number of nodes will be used. num_layers : int, optional Number of recurrent layers (e.g., :obj:`num_layers=2` would mean stacking two LSTMs together) (Default, value = 1) """ def __init__(self, nin, nhid=None, steps=None, num_layers=1, activation=None, device='cpu'): super(Set2Set, self).__init__() self.steps = steps self.nin = nin self.nhid = nin * 2 if nhid is None else nhid if self.nhid <= self.nin: raise ValueError('Set2Set hidden_dim should be larger than input_dim') # the hidden is a concatenation of weighted sum of embedding and LSTM output self.lstm_output_dim = self.nhid - self.nin self.num_layers = num_layers self.lstm = nn.LSTM(self.nhid, self.nin, num_layers=num_layers, batch_first=True).to(device) self.softmax = nn.Softmax(dim=1) def forward(self, x): r""" Applies the pooling on input tensor x Arguments ---------- x: torch.FloatTensor Input tensor of size (B, N, D) Returns ------- x: `torch.FloatTensor` Tensor resulting from the set2set pooling operation. 
""" batch_size = x.shape[0] n = self.steps or x.shape[1] h = (x.new_zeros((self.num_layers, batch_size, self.nin)), x.new_zeros((self.num_layers, batch_size, self.nin))) q_star = x.new_zeros(batch_size, 1, self.nhid) for i in range(n): # q: batch_size x 1 x input_dim q, h = self.lstm(q_star, h) # e: batch_size x n x 1 e = torch.matmul(x, torch.transpose(q, 1, 2)) a = self.softmax(e) r = torch.sum(a * x, dim=1, keepdim=True) q_star = torch.cat([q, r], dim=-1) return torch.squeeze(q_star, dim=1) class FCLayer(nn.Module): r""" A simple fully connected and customizable layer. This layer is centered around a torch.nn.Linear module. The order in which transformations are applied is: #. Dense Layer #. Activation #. Dropout (if applicable) #. Batch Normalization (if applicable) Arguments ---------- in_size: int Input dimension of the layer (the torch.nn.Linear) out_size: int Output dimension of the layer. dropout: float, optional The ratio of units to dropout. No dropout by default. (Default value = 0.) activation: str or callable, optional Activation function to use. (Default value = relu) b_norm: bool, optional Whether to use batch normalization (Default value = False) bias: bool, optional Whether to enable bias in for the linear layer. (Default value = True) init_fn: callable, optional Initialization function to use for the weight of the layer. Default is :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` with :math:`k=\frac{1}{ \text{in_size}}` (Default value = None) Attributes ---------- dropout: int The ratio of units to dropout. 
b_norm: int Whether to use batch normalization linear: torch.nn.Linear The linear layer activation: the torch.nn.Module The activation layer init_fn: function Initialization function used for the weight of the layer in_size: int Input dimension of the linear layer out_size: int Output dimension of the linear layer """ def __init__(self, in_size, out_size, activation='relu', dropout=0., b_norm=False, bias=True, init_fn=None, device='cpu'): super(FCLayer, self).__init__() self.__params = locals() del self.__params['__class__'] del self.__params['self'] self.in_size = in_size self.out_size = out_size self.bias = bias self.linear = nn.Linear(in_size, out_size, bias=bias).to(device) self.dropout = None self.b_norm = None if dropout: self.dropout = nn.Dropout(p=dropout) if b_norm: self.b_norm = nn.BatchNorm1d(out_size).to(device) self.activation = get_activation(activation) self.init_fn = nn.init.xavier_uniform_ self.reset_parameters() def reset_parameters(self, init_fn=None): init_fn = init_fn or self.init_fn if init_fn is not None: init_fn(self.linear.weight, 1 / self.in_size) if self.bias: self.linear.bias.data.zero_() def forward(self, x): h = self.linear(x) if self.activation is not None: h = self.activation(h) if self.dropout is not None: h = self.dropout(h) if self.b_norm is not None: if h.shape[1] != self.out_size: h = self.b_norm(h.transpose(1, 2)).transpose(1, 2) else: h = self.b_norm(h) return h def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_size) + ' -> ' \ + str(self.out_size) + ')' class MLP(nn.Module): """ Simple multi-layer perceptron, built of a series of FCLayers """ def __init__(self, in_size, hidden_size, out_size, layers, mid_activation='relu', last_activation='none', dropout=0., mid_b_norm=False, last_b_norm=False, device='cpu'): super(MLP, self).__init__() self.in_size = in_size self.hidden_size = hidden_size self.out_size = out_size self.fully_connected = nn.ModuleList() if layers <= 1: 
self.fully_connected.append(FCLayer(in_size, out_size, activation=last_activation, b_norm=last_b_norm, device=device, dropout=dropout)) else: self.fully_connected.append(FCLayer(in_size, hidden_size, activation=mid_activation, b_norm=mid_b_norm, device=device, dropout=dropout)) for _ in range(layers - 2): self.fully_connected.append(FCLayer(hidden_size, hidden_size, activation=mid_activation, b_norm=mid_b_norm, device=device, dropout=dropout)) self.fully_connected.append(FCLayer(hidden_size, out_size, activation=last_activation, b_norm=last_b_norm, device=device, dropout=dropout)) def forward(self, x): for fc in self.fully_connected: x = fc(x) return x def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_size) + ' -> ' \ + str(self.out_size) + ')' class GRU(nn.Module): """ Wrapper class for the GRU used by the GNN framework, nn.GRU is used for the Gated Recurrent Unit itself """ def __init__(self, input_size, hidden_size, device): super(GRU, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size).to(device) def forward(self, x, y): """ :param x: shape: (B, N, Din) where Din <= input_size (difference is padded) :param y: shape: (B, N, Dh) where Dh <= hidden_size (difference is padded) :return: shape: (B, N, Dh) """ assert (x.shape[-1] <= self.input_size and y.shape[-1] <= self.hidden_size) (B, N, _) = x.shape x = x.reshape(1, B * N, -1).contiguous() y = y.reshape(1, B * N, -1).contiguous() # padding if necessary if x.shape[-1] < self.input_size: x = F.pad(input=x, pad=[0, self.input_size - x.shape[-1]], mode='constant', value=0) if y.shape[-1] < self.hidden_size: y = F.pad(input=y, pad=[0, self.hidden_size - y.shape[-1]], mode='constant', value=0) x = self.gru(x, y)[1] x = x.reshape(B, N, -1) return x class S2SReadout(nn.Module): """ Performs a Set2Set aggregation of all the graph nodes' features followed by a series of fully connected layers """ def 
__init__(self, in_size, hidden_size, out_size, fc_layers=3, device='cpu', final_activation='relu'): super(S2SReadout, self).__init__() # set2set aggregation self.set2set = Set2Set(in_size, device=device) # fully connected layers self.mlp = MLP(in_size=2 * in_size, hidden_size=hidden_size, out_size=out_size, layers=fc_layers, mid_activation="relu", last_activation=final_activation, mid_b_norm=True, last_b_norm=False, device=device) def forward(self, x): x = self.set2set(x) return self.mlp(x) ================================================ FILE: models/pytorch/gat/layer.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F class GATHead(nn.Module): def __init__(self, in_features, out_features, alpha, activation=True, device='cpu'): super(GATHead, self).__init__() self.in_features = in_features self.out_features = out_features self.activation = activation self.W = nn.Parameter(torch.zeros(size=(in_features, out_features), device=device)) self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1), device=device)) self.leakyrelu = nn.LeakyReLU(alpha) self.reset_parameters() def reset_parameters(self): nn.init.xavier_uniform_(self.W.data, gain=0.1414) nn.init.xavier_uniform_(self.a.data, gain=0.1414) def forward(self, input, adj): h = torch.matmul(input, self.W) (B, N, _) = adj.shape a_input = torch.cat([h.repeat(1, 1, N).view(B, N * N, -1), h.repeat(1, N, 1)], dim=1)\ .view(B, N, -1, 2 * self.out_features) e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(3)) zero_vec = -9e15 * torch.ones_like(e) attention = torch.where(adj > 0, e, zero_vec) attention = F.softmax(attention, dim=1) h_prime = torch.matmul(attention, h) if self.activation: return F.elu(h_prime) else: return h_prime def __repr__(self): return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' class GATLayer(nn.Module): """ Graph Attention Layer, GAT paper at https://arxiv.org/abs/1710.10903 
Implementation inspired by https://github.com/Diego999/pyGAT """ def __init__(self, in_features, out_features, alpha, nheads=1, activation=True, device='cpu'): """ :param in_features: size of the input per node :param out_features: size of the output per node :param alpha: slope of the leaky relu :param nheads: number of attention heads :param activation: whether to apply a non-linearity :param device: device used for computation """ super(GATLayer, self).__init__() assert (out_features % nheads == 0) self.in_features = in_features self.out_features = out_features self.input_head = in_features self.output_head = out_features // nheads self.heads = nn.ModuleList() for _ in range(nheads): self.heads.append(GATHead(in_features=self.input_head, out_features=self.output_head, alpha=alpha, activation=activation, device=device)) def forward(self, input, adj): y = torch.cat([head(input, adj) for head in self.heads], dim=2) return y def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ')' ================================================ FILE: models/pytorch/gcn/layer.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F class GCNLayer(nn.Module): """ GCN layer, similar to https://arxiv.org/abs/1609.02907 Implementation inspired by https://github.com/tkipf/pygcn """ def __init__(self, in_features, out_features, bias=True, device='cpu'): """ :param in_features: size of the input per node :param out_features: size of the output per node :param bias: whether to add a learnable bias before the activation :param device: device used for computation """ super(GCNLayer, self).__init__() self.in_features = in_features self.out_features = out_features self.device = device self.W = nn.Parameter(torch.zeros(size=(in_features, out_features), device=device)) if bias: self.b = nn.Parameter(torch.zeros(out_features, device=device)) else: self.register_parameter('b', None) self.reset_parameters() def
reset_parameters(self): stdv = 1. / math.sqrt(self.W.size(1)) self.W.data.uniform_(-stdv, stdv) if self.b is not None: self.b.data.uniform_(-stdv, stdv) def forward(self, X, adj): (B, N, _) = adj.shape # linear transformation XW = torch.matmul(X, self.W) # normalised mean aggregation adj = adj + torch.eye(N, device=self.device).unsqueeze(0) rD = torch.mul(torch.pow(torch.sum(adj, -1, keepdim=True), -0.5), torch.eye(N, device=self.device).unsqueeze(0)) # D^{-1/2} adj = torch.matmul(torch.matmul(rD, adj), rD) # D^{-1/2} A' D^{-1/2} y = torch.bmm(adj, XW) if self.b is not None: y = y + self.b return F.leaky_relu(y) def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ')' ================================================ FILE: models/pytorch/gin/layer.py ================================================ import torch import torch.nn as nn from models.layers import MLP class GINLayer(nn.Module): """ Graph Isomorphism Network layer, similar to https://arxiv.org/abs/1810.00826 """ def __init__(self, in_features, out_features, fc_layers=2, device='cpu'): """ :param in_features: size of the input per node :param out_features: size of the output per node :param fc_layers: number of fully connected layers after the sum aggregator :param device: device used for computation """ super(GINLayer, self).__init__() self.device = device self.in_features = in_features self.out_features = out_features self.epsilon = nn.Parameter(torch.zeros(size=(1,), device=device)) self.post_transformation = MLP(in_size=in_features, hidden_size=max(in_features, out_features), out_size=out_features, layers=fc_layers, mid_activation='relu', last_activation='relu', mid_b_norm=True, last_b_norm=False, device=device) self.reset_parameters() def reset_parameters(self): self.epsilon.data.fill_(0.1) def forward(self, input, adj): (B, N, _) = adj.shape # sum aggregation mod_adj = adj + torch.eye(N, device=self.device).unsqueeze(0) * (1 +
self.epsilon) support = torch.matmul(mod_adj, input) # post-aggregation transformation return self.post_transformation(support) def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ')' ================================================ FILE: models/pytorch/gnn_framework.py ================================================ import types import torch import torch.nn as nn import torch.nn.functional as F from models.layers import GRU, S2SReadout, MLP class GNN(nn.Module): def __init__(self, nfeat, nhid, nodes_out, graph_out, dropout, conv_layers=2, fc_layers=3, first_conv_descr=None, middle_conv_descr=None, final_activation='LeakyReLU', skip=False, gru=False, fixed=False, variable=False, device='cpu'): """ :param nfeat: number of input features per node :param nhid: number of hidden features per node :param nodes_out: number of nodes' labels :param graph_out: number of graph labels :param dropout: dropout value :param conv_layers: if variable, conv_layers should be a function : adj -> int, otherwise an int :param fc_layers: number of fully connected layers before the labels :param first_conv_descr: dict or SimpleNamespace: "type"-> type of layer, "args" -> dict of calling args :param middle_conv_descr: dict or SimpleNamespace : "type"-> type of layer, "args" -> dict of calling args :param final_activation: activation to be used on the last fc layer before the labels :param skip: whether to use skip connections feeding to the readout :param gru: whether to use a shared GRU after each convolution :param fixed: whether to reuse the same middle convolutional layer multiple times :param variable: whether the number of convolutional layers is variable or fixed :param device: device used for computation """ super(GNN, self).__init__() if variable: assert callable(conv_layers), "conv_layers should be a function from adjacency matrix to int" assert fixed, "With a variable number of layers they must be fixed" 
assert not skip, "cannot have skip and fixed at the same time" else: assert type(conv_layers) == int, "conv_layers should be an int" assert conv_layers > 0, "conv_layers should be greater than 0" if type(first_conv_descr) == dict: first_conv_descr = types.SimpleNamespace(**first_conv_descr) assert type(first_conv_descr) == types.SimpleNamespace, "first_conv_descr should be dict or SimpleNamespace" if type(first_conv_descr.args) == dict: first_conv_descr.args = types.SimpleNamespace(**first_conv_descr.args) assert type(first_conv_descr.args) == types.SimpleNamespace, \ "first_conv_descr.args should be either a dict or a SimpleNamespace" if type(middle_conv_descr) == dict: middle_conv_descr = types.SimpleNamespace(**middle_conv_descr) assert type(middle_conv_descr) == types.SimpleNamespace, "middle_conv_descr should be dict or SimpleNamespace" if type(middle_conv_descr.args) == dict: middle_conv_descr.args = types.SimpleNamespace(**middle_conv_descr.args) assert type(middle_conv_descr.args) == types.SimpleNamespace, \ "middle_conv_descr.args should be either a dict or a SimpleNamespace" self.dropout = dropout self.conv_layers = nn.ModuleList() self.skip = skip self.fixed = fixed self.variable = variable self.n_fixed_conv = conv_layers self.gru = GRU(input_size=nhid, hidden_size=nhid, device=device) if gru else None # first graph convolution first_conv_descr.args.in_features = nfeat first_conv_descr.args.out_features = nhid first_conv_descr.args.device = device self.conv_layers.append(first_conv_descr.layer_type(**vars(first_conv_descr.args))) # middle graph convolutions middle_conv_descr.args.in_features = nhid middle_conv_descr.args.out_features = nhid middle_conv_descr.args.device = device for l in range(1 if fixed else conv_layers - 1): self.conv_layers.append( middle_conv_descr.layer_type(**vars(middle_conv_descr.args))) n_conv_out = nfeat + conv_layers * nhid if skip else nhid # nodes output: fully connected layers self.nodes_read_out = MLP(in_size=n_conv_out, 
hidden_size=n_conv_out, out_size=nodes_out, layers=fc_layers, mid_activation="LeakyReLU", last_activation=final_activation, device=device) # graph output: S2S readout self.graph_read_out = S2SReadout(n_conv_out, n_conv_out, graph_out, fc_layers=fc_layers, device=device, final_activation=final_activation) def forward(self, x, adj): # graph convolutions skip_connections = [x] if self.skip else None n_layers = self.n_fixed_conv(adj) if self.variable else self.n_fixed_conv conv_layers = [self.conv_layers[0]] + ([self.conv_layers[1]] * (n_layers - 1)) if self.fixed else self.conv_layers for layer, conv in enumerate(conv_layers): y = conv(x, adj) x = y if self.gru is None else self.gru(x, y) if self.skip: skip_connections.append(x) # dropout at all layers but the last if layer != n_layers - 1: x = F.dropout(x, self.dropout, training=self.training) if self.skip: x = torch.cat(skip_connections, dim=2) # readout output return (self.nodes_read_out(x), self.graph_read_out(x)) ================================================ FILE: models/pytorch/pna/aggregators.py ================================================ import math import torch EPS = 1e-5 # each aggregator is a function taking as input X (B x N x N x Din), adj (B x N x N), self_loop and device and # returning the aggregated value of X (B x N x Din) for each dimension def aggregate_identity(X, adj, self_loop=False, device='cpu'): # Y corresponds to the elements of the main diagonal of X (_, N, N, _) = X.shape Y = torch.sum(torch.mul(X, torch.eye(N, device=device).reshape(1, N, N, 1)), dim=2) return Y def aggregate_mean(X, adj, self_loop=False, device='cpu'): # D^{-1} A * X i.e.
the mean of the neighbours if self_loop: # add self connections (B, N, _) = adj.shape adj = adj + torch.eye(N, device=device).unsqueeze(0) D = torch.sum(adj, -1, keepdim=True) X_sum = torch.sum(torch.mul(X, adj.unsqueeze(-1)), dim=2) X_mean = torch.div(X_sum, D) return X_mean def aggregate_max(X, adj, min_value=-math.inf, self_loop=False, device='cpu'): (B, N, N, Din) = X.shape if self_loop: # add self connections adj = adj + torch.eye(N, device=device).unsqueeze(0) adj = adj.unsqueeze(-1) # adding extra dimension M = torch.where(adj > 0.0, X, torch.tensor(min_value, device=device)) max = torch.max(M, -3)[0] return max def aggregate_min(X, adj, max_value=math.inf, self_loop=False, device='cpu'): (B, N, N, Din) = X.shape if self_loop: # add self connections adj = adj + torch.eye(N, device=device).unsqueeze(0) adj = adj.unsqueeze(-1) # adding extra dimension M = torch.where(adj > 0.0, X, torch.tensor(max_value, device=device)) min = torch.min(M, -3)[0] return min def aggregate_std(X, adj, self_loop=False, device='cpu'): # sqrt(relu(D^{-1} A X^2 - (D^{-1} A X)^2) + EPS) i.e. the standard deviation of the features of the neighbours # the EPS is added for the stability of the derivative of the square root std = torch.sqrt(aggregate_var(X, adj, self_loop, device) + EPS) # sqrt(mean_squares_X - mean_X^2) return std def aggregate_var(X, adj, self_loop=False, device='cpu'): # relu(D^{-1} A X^2 - (D^{-1} A X)^2) i.e. the variance of the features of the neighbours if self_loop: # add self connections (B, N, _) = adj.shape adj = adj + torch.eye(N, device=device).unsqueeze(0) D = torch.sum(adj, -1, keepdim=True) X_sum_squares = torch.sum(torch.mul(torch.mul(X, X), adj.unsqueeze(-1)), dim=2) X_mean_squares = torch.div(X_sum_squares, D) # D^{-1} A X^2 X_mean = aggregate_mean(X, adj) # D^{-1} A X var = torch.relu(X_mean_squares - torch.mul(X_mean, X_mean)) # relu(mean_squares_X - mean_X^2) return var def aggregate_sum(X, adj, self_loop=False, device='cpu'): # A * X i.e. 
the sum of the neighbours if self_loop: # add self connections (B, N, _) = adj.shape adj = adj + torch.eye(N, device=device).unsqueeze(0) X_sum = torch.sum(torch.mul(X, adj.unsqueeze(-1)), dim=2) return X_sum def aggregate_normalised_mean(X, adj, self_loop=False, device='cpu'): # D^{-1/2} A D^{-1/2} X (B, N, N, _) = X.shape if self_loop: # add self connections adj = adj + torch.eye(N, device=device).unsqueeze(0) rD = torch.mul(torch.pow(torch.sum(adj, -1, keepdim=True), -0.5), torch.eye(N, device=device) .unsqueeze(0).repeat(B, 1, 1)) # D^{-1/2} adj = torch.matmul(torch.matmul(rD, adj), rD) # D^{-1/2} A' D^{-1/2} X_sum = torch.sum(torch.mul(X, adj.unsqueeze(-1)), dim=2) return X_sum def aggregate_softmax(X, adj, self_loop=False, device='cpu'): # for each node sum_i(x_i*exp(x_i)/sum_j(exp(x_j))) where x_i and x_j vary over the neighbourhood of the node (B, N, N, Din) = X.shape if self_loop: # add self connections adj = adj + torch.eye(N, device=device).unsqueeze(0) X_exp = torch.exp(X) adj = adj.unsqueeze(-1) # adding extra dimension X_exp = torch.mul(X_exp, adj) X_sum = torch.sum(X_exp, dim=2, keepdim=True) softmax = torch.sum(torch.mul(torch.div(X_exp, X_sum), X), dim=2) return softmax def aggregate_softmin(X, adj, self_loop=False, device='cpu'): # for each node sum_i(x_i*exp(-x_i)/sum_j(exp(-x_j))) where x_i and x_j vary over the neighbourhood of the node return -aggregate_softmax(-X, adj, self_loop=self_loop, device=device) def aggregate_moment(X, adj, self_loop=False, device='cpu', n=3): # for each node (E[(X-E[X])^n])^{1/n} # EPS is added to the absolute value of expectation before taking the nth root for stability if self_loop: # add self connections (B, N, _) = adj.shape adj = adj + torch.eye(N, device=device).unsqueeze(0) D = torch.sum(adj, -1, keepdim=True) X_mean = aggregate_mean(X, adj, device=device) # adj already contains any self loops added above X_n = torch.div(torch.sum(torch.mul(torch.pow(X - X_mean.unsqueeze(2), n), adj.unsqueeze(-1)), dim=2), D) rooted_X_n = torch.sign(X_n)
* torch.pow(torch.abs(X_n) + EPS, 1. / n) return rooted_X_n def aggregate_moment_3(X, adj, self_loop=False, device='cpu'): return aggregate_moment(X, adj, self_loop=self_loop, device=device, n=3) def aggregate_moment_4(X, adj, self_loop=False, device='cpu'): return aggregate_moment(X, adj, self_loop=self_loop, device=device, n=4) def aggregate_moment_5(X, adj, self_loop=False, device='cpu'): return aggregate_moment(X, adj, self_loop=self_loop, device=device, n=5) AGGREGATORS = {'mean': aggregate_mean, 'sum': aggregate_sum, 'max': aggregate_max, 'min': aggregate_min, 'identity': aggregate_identity, 'std': aggregate_std, 'var': aggregate_var, 'normalised_mean': aggregate_normalised_mean, 'softmax': aggregate_softmax, 'softmin': aggregate_softmin, 'moment3': aggregate_moment_3, 'moment4': aggregate_moment_4, 'moment5': aggregate_moment_5} ================================================ FILE: models/pytorch/pna/layer.py ================================================ import torch import torch.nn as nn from models.pytorch.pna.aggregators import AGGREGATORS from models.pytorch.pna.scalers import SCALERS from models.layers import FCLayer, MLP class PNATower(nn.Module): def __init__(self, in_features, out_features, aggregators, scalers, avg_d, self_loop, pretrans_layers, posttrans_layers, device): """ :param in_features: size of the input per node of the tower :param out_features: size of the output per node of the tower :param aggregators: set of aggregation functions each taking as input X (B x N x N x Din), adj (B x N x N), self_loop and device :param scalers: set of scaling functions each taking as input X (B x N x Din), adj (B x N x N) and avg_d """ super(PNATower, self).__init__() self.device = device self.in_features = in_features self.out_features = out_features self.aggregators = aggregators self.scalers = scalers self.self_loop = self_loop self.pretrans = MLP(in_size=2 * self.in_features, hidden_size=self.in_features, out_size=self.in_features, 
layers=pretrans_layers, mid_activation='relu', last_activation='none') self.posttrans = MLP(in_size=(len(aggregators) * len(scalers) + 1) * self.in_features, hidden_size=self.out_features, out_size=self.out_features, layers=posttrans_layers, mid_activation='relu', last_activation='none') self.avg_d = avg_d def forward(self, input, adj): (B, N, _) = adj.shape # pre-aggregation transformation h_i = input.unsqueeze(2).repeat(1, 1, N, 1) h_j = input.unsqueeze(1).repeat(1, N, 1, 1) h_cat = torch.cat([h_i, h_j], dim=3) h_mod = self.pretrans(h_cat) # aggregation m = torch.cat([aggregate(h_mod, adj, self_loop=self.self_loop, device=self.device) for aggregate in self.aggregators], dim=2) m = torch.cat([scale(m, adj, avg_d=self.avg_d) for scale in self.scalers], dim=2) # post-aggregation transformation m_cat = torch.cat([input, m], dim=2) out = self.posttrans(m_cat) return out def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ')' class PNALayer(nn.Module): """ Implements a single convolutional layer of the Principal Neighbourhood Aggregation Networks as described in https://arxiv.org/abs/2004.05718 """ def __init__(self, in_features, out_features, aggregators, scalers, avg_d, towers=1, self_loop=False, pretrans_layers=1, posttrans_layers=1, divide_input=True, device='cpu'): """ :param in_features: size of the input per node :param out_features: size of the output per node :param aggregators: set of aggregation function identifiers :param scalers: set of scaling functions identifiers :param avg_d: average degree of nodes in the training set, used by scalers to normalize :param self_loop: whether to add a self loop in the adjacency matrix when aggregating :param pretrans_layers: number of layers in the transformation before the aggregation :param posttrans_layers: number of layers in the transformation after the aggregation :param divide_input: whether the input features should be split between towers or 
not :param device: device used for computation """ super(PNALayer, self).__init__() assert ((not divide_input) or in_features % towers == 0), "if divide_input is set the number of towers has to divide in_features" assert (out_features % towers == 0), "the number of towers has to divide the out_features" # retrieve the aggregators and scalers functions aggregators = [AGGREGATORS[aggr] for aggr in aggregators] scalers = [SCALERS[scale] for scale in scalers] self.divide_input = divide_input self.in_features = in_features self.out_features = out_features self.input_tower = in_features // towers if divide_input else in_features self.output_tower = out_features // towers # convolution self.towers = nn.ModuleList() for _ in range(towers): self.towers.append( PNATower(in_features=self.input_tower, out_features=self.output_tower, aggregators=aggregators, scalers=scalers, avg_d=avg_d, self_loop=self_loop, pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers, device=device)) # mixing network self.mixing_network = FCLayer(out_features, out_features, activation='LeakyReLU') def forward(self, input, adj): # convolution if self.divide_input: y = torch.cat( [tower(input[:, :, n_tower * self.input_tower: (n_tower + 1) * self.input_tower], adj) for n_tower, tower in enumerate(self.towers)], dim=2) else: y = torch.cat([tower(input, adj) for tower in self.towers], dim=2) # mixing network return self.mixing_network(y) def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ')' ================================================ FILE: models/pytorch/pna/scalers.py ================================================ import torch # each scaler is a function that takes as input X (B x N x Din), adj (B x N x N) and # avg_d (dictionary containing averages over training set) and returns X_scaled (B x N x Din) as output def scale_identity(X, adj, avg_d=None): return X def scale_amplification(X, adj, avg_d=None): # log(D + 1) / d * X where d is the average of the ``log(D + 1)`` in the
training set D = torch.sum(adj, -1) scale = (torch.log(D + 1) / avg_d["log"]).unsqueeze(-1) X_scaled = torch.mul(scale, X) return X_scaled def scale_attenuation(X, adj, avg_d=None): # d / log(D + 1) * X where d is the average of the ``log(D + 1)`` in the training set D = torch.sum(adj, -1) scale = (avg_d["log"] / torch.log(D + 1)).unsqueeze(-1) X_scaled = torch.mul(scale, X) return X_scaled def scale_linear(X, adj, avg_d=None): # d^{-1} D X where d is the average degree in the training set D = torch.sum(adj, -1, keepdim=True) X_scaled = D * X / avg_d["lin"] return X_scaled def scale_inverse_linear(X, adj, avg_d=None): # d D^{-1} X where d is the average degree in the training set D = torch.sum(adj, -1, keepdim=True) X_scaled = avg_d["lin"] * X / D return X_scaled SCALERS = {'identity': scale_identity, 'amplification': scale_amplification, 'attenuation': scale_attenuation, 'linear': scale_linear, 'inverse_linear': scale_inverse_linear} ================================================ FILE: models/pytorch_geometric/aggregators.py ================================================ import torch from torch import Tensor from torch_scatter import scatter from typing import Optional # Implemented with the help of Matthias Fey, author of PyTorch Geometric # For an example see https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pna.py def aggregate_sum(src: Tensor, index: Tensor, dim_size: Optional[int]): return scatter(src, index, 0, None, dim_size, reduce='sum') def aggregate_mean(src: Tensor, index: Tensor, dim_size: Optional[int]): return scatter(src, index, 0, None, dim_size, reduce='mean') def aggregate_min(src: Tensor, index: Tensor, dim_size: Optional[int]): return scatter(src, index, 0, None, dim_size, reduce='min') def aggregate_max(src: Tensor, index: Tensor, dim_size: Optional[int]): return scatter(src, index, 0, None, dim_size, reduce='max') def aggregate_var(src, index, dim_size): mean = aggregate_mean(src, index, dim_size) mean_squares =
aggregate_mean(src * src, index, dim_size) return mean_squares - mean * mean def aggregate_std(src, index, dim_size): return torch.sqrt(torch.relu(aggregate_var(src, index, dim_size)) + 1e-5) AGGREGATORS = { 'sum': aggregate_sum, 'mean': aggregate_mean, 'min': aggregate_min, 'max': aggregate_max, 'var': aggregate_var, 'std': aggregate_std, } ================================================ FILE: models/pytorch_geometric/example.py ================================================ import torch import torch.nn.functional as F from torch.nn import ModuleList from torch.nn import Sequential, ReLU, Linear from torch.optim.lr_scheduler import ReduceLROnPlateau from torch_geometric.utils import degree from ogb.graphproppred import PygGraphPropPredDataset, Evaluator from ogb.graphproppred.mol_encoder import AtomEncoder from torch_geometric.data import DataLoader from torch_geometric.nn import BatchNorm, global_mean_pool from models.pytorch_geometric.pna import PNAConvSimple dataset = PygGraphPropPredDataset(name="ogbg-molhiv") split_idx = dataset.get_idx_split() train_loader = DataLoader(dataset[split_idx["train"]], batch_size=128, shuffle=True) val_loader = DataLoader(dataset[split_idx["valid"]], batch_size=128, shuffle=False) test_loader = DataLoader(dataset[split_idx["test"]], batch_size=128, shuffle=False) # Compute in-degree histogram over training data. 
deg = torch.zeros(10, dtype=torch.long) for data in dataset[split_idx['train']]: d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) deg += torch.bincount(d, minlength=deg.numel()) class Net(torch.nn.Module): def __init__(self): super(Net, self).__init__() self.node_emb = AtomEncoder(emb_dim=80) aggregators = ['mean', 'min', 'max', 'std'] scalers = ['identity', 'amplification', 'attenuation'] self.convs = ModuleList() self.batch_norms = ModuleList() for _ in range(4): conv = PNAConvSimple(in_channels=80, out_channels=80, aggregators=aggregators, scalers=scalers, deg=deg, post_layers=1) self.convs.append(conv) self.batch_norms.append(BatchNorm(80)) self.mlp = Sequential(Linear(80, 40), ReLU(), Linear(40, 20), ReLU(), Linear(20, 1)) def forward(self, x, edge_index, edge_attr, batch): x = self.node_emb(x) for conv, batch_norm in zip(self.convs, self.batch_norms): h = F.relu(batch_norm(conv(x, edge_index, edge_attr))) x = h + x # residual# x = F.dropout(x, 0.3, training=self.training) x = global_mean_pool(x, batch) return self.mlp(x) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=3e-6) scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20, min_lr=0.0001) def train(epoch): model.train() total_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() out = model(data.x, data.edge_index, None, data.batch) loss = torch.nn.BCEWithLogitsLoss()(out.to(torch.float32), data.y.to(torch.float32)) loss.backward() total_loss += loss.item() * data.num_graphs optimizer.step() return total_loss / len(train_loader.dataset) @torch.no_grad() def test(loader): model.eval() evaluator = Evaluator(name='ogbg-molhiv') list_pred = [] list_labels = [] for data in loader: data = data.to(device) out = model(data.x, data.edge_index, None, data.batch) list_pred.append(out) list_labels.append(data.y) epoch_test_ROC 
= evaluator.eval({'y_pred': torch.cat(list_pred), 'y_true': torch.cat(list_labels)})['rocauc'] return epoch_test_ROC best = (0, 0) for epoch in range(1, 201): loss = train(epoch) val_roc = test(val_loader) test_roc = test(test_loader) scheduler.step(val_roc) print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_roc:.4f}, ' f'Test: {test_roc:.4f}') if val_roc > best[0]: best = (val_roc, test_roc) print(f'Best epoch val: {best[0]:.4f}, test: {best[1]:.4f}') ================================================ FILE: models/pytorch_geometric/pna.py ================================================ from typing import Optional, List, Dict from torch_geometric.typing import Adj, OptTensor import torch from torch import Tensor from torch.nn import ModuleList, Sequential, Linear, ReLU from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import reset from torch_geometric.utils import degree from models.pytorch_geometric.aggregators import AGGREGATORS from models.pytorch_geometric.scalers import SCALERS # Implemented with the help of Matthias Fey, author of PyTorch Geometric # For an example see https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pna.py class PNAConv(MessagePassing): r"""The Principal Neighbourhood Aggregation graph convolution operator from the `"Principal Neighbourhood Aggregation for Graph Nets" <https://arxiv.org/abs/2004.05718>`_ paper .. math:: \bigoplus = \underbrace{\begin{bmatrix}I \\ S(D, \alpha=1) \\ S(D, \alpha=-1) \end{bmatrix} }_{\text{scalers}} \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min \end{bmatrix}}_{\text{aggregators}}, in: .. math:: X_i^{(t+1)} = U \left( X_i^{(t)}, \underset{(j,i) \in E}{\bigoplus} M \left( X_i^{(t)}, X_j^{(t)} \right) \right) where :math:`M` and :math:`U` denote the MLPs referred to as pretrans and posttrans respectively. Args: in_channels (int): Size of each input sample. out_channels (int): Size of each output sample.
aggregators (list of str): Set of aggregation function identifiers, namely :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"var"` and :obj:`"std"`. scalers: (list of str): Set of scaling function identifiers, namely :obj:`"identity"`, :obj:`"amplification"`, :obj:`"attenuation"`, :obj:`"linear"` and :obj:`"inverse_linear"`. deg (Tensor): Histogram of in-degrees of nodes in the training set, used by scalers to normalize. edge_dim (int, optional): Edge feature dimensionality (in case there are any). (default :obj:`None`) towers (int, optional): Number of towers (default: :obj:`1`). pre_layers (int, optional): Number of transformation layers before aggregation (default: :obj:`1`). post_layers (int, optional): Number of transformation layers after aggregation (default: :obj:`1`). divide_input (bool, optional): Whether the input features should be split between towers or not (default: :obj:`False`). **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
""" def __init__(self, in_channels: int, out_channels: int, aggregators: List[str], scalers: List[str], deg: Tensor, edge_dim: Optional[int] = None, towers: int = 1, pre_layers: int = 1, post_layers: int = 1, divide_input: bool = False, **kwargs): super(PNAConv, self).__init__(aggr=None, node_dim=0, **kwargs) if divide_input: assert in_channels % towers == 0 assert out_channels % towers == 0 self.in_channels = in_channels self.out_channels = out_channels self.aggregators = [AGGREGATORS[aggr] for aggr in aggregators] self.scalers = [SCALERS[scale] for scale in scalers] self.edge_dim = edge_dim self.towers = towers self.divide_input = divide_input self.F_in = in_channels // towers if divide_input else in_channels self.F_out = self.out_channels // towers deg = deg.to(torch.float) total_no_vertices = deg.sum() bin_degrees = torch.arange(len(deg)) self.avg_deg: Dict[str, float] = { 'lin': ((bin_degrees * deg).sum() / total_no_vertices).item(), 'log': (((bin_degrees + 1).log() * deg).sum() / total_no_vertices).item(), 'exp': ((bin_degrees.exp() * deg).sum() / total_no_vertices).item(), } if self.edge_dim is not None: self.edge_encoder = Linear(edge_dim, self.F_in) self.pre_nns = ModuleList() self.post_nns = ModuleList() for _ in range(towers): modules = [Linear((3 if edge_dim else 2) * self.F_in, self.F_in)] for _ in range(pre_layers - 1): modules += [ReLU()] modules += [Linear(self.F_in, self.F_in)] self.pre_nns.append(Sequential(*modules)) in_channels = (len(aggregators) * len(scalers) + 1) * self.F_in modules = [Linear(in_channels, self.F_out)] for _ in range(post_layers - 1): modules += [ReLU()] modules += [Linear(self.F_out, self.F_out)] self.post_nns.append(Sequential(*modules)) self.lin = Linear(out_channels, out_channels) self.reset_parameters() def reset_parameters(self): if self.edge_dim is not None: self.edge_encoder.reset_parameters() for nn in self.pre_nns: reset(nn) for nn in self.post_nns: reset(nn) self.lin.reset_parameters() def forward(self, x: Tensor, 
edge_index: Adj, edge_attr: OptTensor = None) -> Tensor: if self.divide_input: x = x.view(-1, self.towers, self.F_in) else: x = x.view(-1, 1, self.F_in).repeat(1, self.towers, 1) # propagate_type: (x: Tensor, edge_attr: OptTensor) out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None) out = torch.cat([x, out], dim=-1) outs = [nn(out[:, i]) for i, nn in enumerate(self.post_nns)] out = torch.cat(outs, dim=1) return self.lin(out) def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor) -> Tensor: h: Tensor = x_i # Dummy. if edge_attr is not None: edge_attr = self.edge_encoder(edge_attr) edge_attr = edge_attr.view(-1, 1, self.F_in) edge_attr = edge_attr.repeat(1, self.towers, 1) h = torch.cat([x_i, x_j, edge_attr], dim=-1) else: h = torch.cat([x_i, x_j], dim=-1) hs = [nn(h[:, i]) for i, nn in enumerate(self.pre_nns)] return torch.stack(hs, dim=1) def aggregate(self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None) -> Tensor: outs = [aggr(inputs, index, dim_size) for aggr in self.aggregators] out = torch.cat(outs, dim=-1) deg = degree(index, dim_size, dtype=inputs.dtype).view(-1, 1, 1) outs = [scaler(out, deg, self.avg_deg) for scaler in self.scalers] return torch.cat(outs, dim=-1) def __repr__(self): return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, towers={self.towers}, edge_dim={self.edge_dim})') class PNAConvSimple(MessagePassing): r"""The Principal Neighbourhood Aggregation graph convolution operator from the `"Principal Neighbourhood Aggregation for Graph Nets" <https://arxiv.org/abs/2004.05718>`_ paper .. math:: \bigoplus = \underbrace{\begin{bmatrix}I \\ S(D, \alpha=1) \\ S(D, \alpha=-1) \end{bmatrix} }_{\text{scalers}} \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min \end{bmatrix}}_{\text{aggregators}}, in: .. math:: X_i^{(t+1)} = U \left( \underset{(j,i) \in E}{\bigoplus} M \left(X_j^{(t)} \right) \right) where :math:`U` denotes the MLP referred to as posttrans.
Args: in_channels (int): Size of each input sample. out_channels (int): Size of each output sample. aggregators (list of str): Set of aggregation function identifiers, namely :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"var"` and :obj:`"std"`. scalers: (list of str): Set of scaling function identifiers, namely :obj:`"identity"`, :obj:`"amplification"`, :obj:`"attenuation"`, :obj:`"linear"` and :obj:`"inverse_linear"`. deg (Tensor): Histogram of in-degrees of nodes in the training set, used by scalers to normalize. post_layers (int, optional): Number of transformation layers after aggregation (default: :obj:`1`). **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. """ def __init__(self, in_channels: int, out_channels: int, aggregators: List[str], scalers: List[str], deg: Tensor, post_layers: int = 1, **kwargs): super(PNAConvSimple, self).__init__(aggr=None, node_dim=0, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.aggregators = [AGGREGATORS[aggr] for aggr in aggregators] self.scalers = [SCALERS[scale] for scale in scalers] self.F_in = in_channels self.F_out = self.out_channels deg = deg.to(torch.float) total_no_vertices = deg.sum() bin_degrees = torch.arange(len(deg)) self.avg_deg: Dict[str, float] = { 'lin': ((bin_degrees * deg).sum() / total_no_vertices).item(), 'log': (((bin_degrees + 1).log() * deg).sum() / total_no_vertices).item(), 'exp': ((bin_degrees.exp() * deg).sum() / total_no_vertices).item(), } in_channels = (len(aggregators) * len(scalers)) * self.F_in modules = [Linear(in_channels, self.F_out)] for _ in range(post_layers - 1): modules += [ReLU()] modules += [Linear(self.F_out, self.F_out)] self.post_nn = Sequential(*modules) self.reset_parameters() def reset_parameters(self): reset(self.post_nn) def forward(self, x: Tensor, edge_index: Adj, edge_attr: OptTensor = None) -> Tensor: # propagate_type: (x: Tensor) out = self.propagate(edge_index, x=x, size=None) 
return self.post_nn(out) def message(self, x_j: Tensor) -> Tensor: return x_j def aggregate(self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None) -> Tensor: outs = [aggr(inputs, index, dim_size) for aggr in self.aggregators] out = torch.cat(outs, dim=-1) deg = degree(index, dim_size, dtype=inputs.dtype).view(-1, 1) outs = [scaler(out, deg, self.avg_deg) for scaler in self.scalers] return torch.cat(outs, dim=-1) def __repr__(self): return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels})') ================================================ FILE: models/pytorch_geometric/scalers.py ================================================ import torch from torch import Tensor from typing import Dict # Implemented with the help of Matthias Fey, author of PyTorch Geometric # For an example see https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pna.py def scale_identity(src: Tensor, deg: Tensor, avg_deg: Dict[str, float]): return src def scale_amplification(src: Tensor, deg: Tensor, avg_deg: Dict[str, float]): return src * (torch.log(deg + 1) / avg_deg['log']) def scale_attenuation(src: Tensor, deg: Tensor, avg_deg: Dict[str, float]): scale = avg_deg['log'] / torch.log(deg + 1) scale[deg == 0] = 1 return src * scale def scale_linear(src: Tensor, deg: Tensor, avg_deg: Dict[str, float]): return src * (deg / avg_deg['lin']) def scale_inverse_linear(src: Tensor, deg: Tensor, avg_deg: Dict[str, float]): scale = avg_deg['lin'] / deg scale[deg == 0] = 1 return src * scale SCALERS = { 'identity': scale_identity, 'amplification': scale_amplification, 'attenuation': scale_attenuation, 'linear': scale_linear, 'inverse_linear': scale_inverse_linear } ================================================ FILE: multitask_benchmark/README.md ================================================ # Multi-task benchmark Real world results ## Overview We provide the scripts for the generation and execution of the multi-task
benchmark. - `datasets_generation` contains: - `graph_generation.py` with scripts to generate the various graphs and add randomness; - `graph_algorithms.py` with the implementation of many algorithms on graphs that can be used as labels; - `multitask_dataset.py` unifies the two files above, generating and saving the benchmarks we used in the paper. - `util` contains: - preprocessing subroutines and loss functions (`util.py`); - general training and evaluation procedures (`train.py`). - `train` contains a script for each model which sets up the command-line parameters and initiates the training procedure. This benchmark uses the PyTorch version of PNA (`../models/pytorch/pna`). Below you can find the instructions on how to create the dataset and run the models; these are also available in this [notebook](https://colab.research.google.com/drive/17NntHxoKQzpKmi8siMOLP9WfANlwbW8S?usp=sharing). ## Dependencies Install PyTorch from the [official website](https://pytorch.org/). The code was tested with PyTorch 1.4. Move to the root of the repository before running the following. Then install the other dependencies: ``` pip3 install -r multitask_benchmark/requirements.txt ``` ## Test run Generate the benchmark dataset (add `--extrapolation` for multiple test sets of different sizes): ``` python3 -m multitask_benchmark.datasets_generation.multitask_dataset ``` Then run the training: ``` python3 -m multitask_benchmark.train.pna --variable --fixed --gru --lr=0.003 --weight_decay=1e-6 --dropout=0.0 --epochs=10000 --patience=1000 --variable_conv_layers=N/2 --fc_layers=3 --hidden=16 --towers=4 --aggregators="mean max min std" --scalers="identity amplification attenuation" --data=multitask_benchmark/data/multitask_dataset.pkl ``` The command above uses the hyperparameters tuned for the non-extrapolating dataset and the architecture outlined in the diagram below.
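To make the aggregator/scaler grid used by `--aggregators` and `--scalers` concrete: with the four aggregators and three scalers above, each node's neighbourhood is summarised by 4 × 3 = 12 scaled values per input feature, which each tower then concatenates with the node's own features before its post-aggregation MLP (see `models/pytorch/pna/layer.py`). A minimal standalone sketch in plain Python, not the repository's implementation, using hypothetical toy numbers:

```python
# Hypothetical toy values (not from the repository): one node with
# neighbourhood features {1, 2, 4}, i.e. degree 3, and an assumed
# training-set average of log(D + 1) corresponding to an average degree of 2.
import math

neigh = [1.0, 2.0, 4.0]
d = len(neigh)
avg_log_d = math.log(1 + 2)

mean = sum(neigh) / d
mx, mn = max(neigh), min(neigh)
std = math.sqrt(sum((x - mean) ** 2 for x in neigh) / d)
aggregated = [mean, mx, mn, std]        # the 4 aggregators from --aggregators

scalers = {                             # the 3 scalers from --scalers
    'identity': 1.0,
    'amplification': math.log(d + 1) / avg_log_d,  # > 1 for above-average degree
    'attenuation': avg_log_d / math.log(d + 1),    # < 1 for above-average degree
}
# 4 aggregators x 3 scalers = 12 values per input feature, concatenated
# before the post-aggregation transformation
summary = [a * s for s in scalers.values() for a in aggregated]
assert len(summary) == 12
```

The amplification and attenuation scalers rescale every aggregated value up or down depending on how the node's degree compares to the training-set average, which is what lets the network distinguish sparse from dense neighbourhoods.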
For more details on the architecture, how the hyperparameters were tuned and the results collected refer to our [paper](https://arxiv.org/abs/2004.05718). ![architecture](images/architecture.png) ================================================ FILE: multitask_benchmark/datasets_generation/graph_algorithms.py ================================================ import math from queue import Queue import numpy as np def is_connected(A): """ :param A:np.array the adjacency matrix :return:bool whether the graph is connected or not """ for _ in range(int(1 + math.ceil(math.log2(A.shape[0])))): A = np.dot(A, A) return np.min(A) > 0 def identity(A, F): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return:F """ return F def first_neighbours(A): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: for each node, the number of nodes reachable in 1 hop """ return np.sum(A > 0, axis=0) def second_neighbours(A): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: for each node, the number of nodes reachable in no more than 2 hops """ A = A > 0.0 A = A + np.dot(A, A) np.fill_diagonal(A, 0) return np.sum(A > 0, axis=0) def kth_neighbours(A, k): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: for each node, the number of nodes reachable in k hops """ A = A > 0.0 R = np.zeros(A.shape) for _ in range(k): R = np.dot(R, A) + A np.fill_diagonal(R, 0) return np.sum(R > 0, axis=0) def map_reduce_neighbourhood(A, F, f_reduce, f_map=None, hops=1, consider_itself=False): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: for each node, map its neighbourhood with f_map, and reduce it with f_reduce """ if f_map is not None: F = f_map(F) A = np.array(A) A = A > 0 R = np.zeros(A.shape) for _ in range(hops): R = np.dot(R, A) + A np.fill_diagonal(R, 1 if consider_itself else 0) R = R > 0 return 
np.array([f_reduce(F[R[i]]) for i in range(A.shape[0])]) def max_neighbourhood(A, F): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: for each node, the maximum in its neighbourhood """ return map_reduce_neighbourhood(A, F, np.max, consider_itself=True) def min_neighbourhood(A, F): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: for each node, the minimum in its neighbourhood """ return map_reduce_neighbourhood(A, F, np.min, consider_itself=True) def std_neighbourhood(A, F): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: for each node, the standard deviation of its neighbourhood """ return map_reduce_neighbourhood(A, F, np.std, consider_itself=True) def mean_neighbourhood(A, F): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: for each node, the mean of its neighbourhood """ return map_reduce_neighbourhood(A, F, np.mean, consider_itself=True) def local_maxima(A, F): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: for each node, whether it is the maximum in its neighbourhood """ return F == map_reduce_neighbourhood(A, F, np.max, consider_itself=True) def graph_laplacian(A): """ :param A:np.array the adjacency matrix :return: the laplacian of the adjacency matrix """ L = (A > 0) * -1 np.fill_diagonal(L, np.sum(A > 0, axis=0)) return L def graph_laplacian_features(A, F): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: the laplacian of the adjacency matrix multiplied by the features """ return np.matmul(graph_laplacian(A), F) def isomorphism(A1, A2, F1=None, F2=None): """ Takes two adjacency matrices (A1,A2) and (optionally) two lists of features. 
It uses the Weisfeiler-Lehman algorithm, so false positives might arise :param A1: adj_matrix, N*N numpy matrix :param A2: adj_matrix, N*N numpy matrix :param F1: node_values, numpy array of size N :param F2: node_values, numpy array of size N :return: isomorphic: boolean which is false when the two graphs are not isomorphic, true when they probably are. """ N = A1.shape[0] if (F1 is None) ^ (F2 is None): raise ValueError("either both or none between F1,F2 must be defined.") if F1 is None: # Assign same initial value to each node F1 = np.ones(N, int) F2 = np.ones(N, int) else: if not np.array_equal(np.sort(F1), np.sort(F2)): return False if F1.dtype != int: raise NotImplementedError('Still have to implement this') p = 1000000007 def mapping(F): return (F * 234 + 133) % p def adjacency_hash(F): F = np.sort(F) b = 257 h = 0 for f in F: h = (b * h + f) % p return h for i in range(N): F1 = map_reduce_neighbourhood(A1, F1, adjacency_hash, f_map=mapping, consider_itself=True, hops=1) F2 = map_reduce_neighbourhood(A2, F2, adjacency_hash, f_map=mapping, consider_itself=True, hops=1) if not np.array_equal(np.sort(F1), np.sort(F2)): return False return True def count_edges(A): """ :param A:np.array the adjacency matrix :return: the number of edges in the graph """ return np.sum(A) / 2 def is_eulerian_cyclable(A): """ :param A:np.array the adjacency matrix :return: whether the graph has an Eulerian cycle """ return is_connected(A) and np.count_nonzero(first_neighbours(A) % 2 == 1) == 0 def is_eulerian_percorrible(A): """ :param A:np.array the adjacency matrix :return: whether the graph has an Eulerian path """ return is_connected(A) and np.count_nonzero(first_neighbours(A) % 2 == 1) in [0, 2] def map_reduce_graph(A, F, f_reduce): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: the features of the nodes reduced by f_reduce """ return f_reduce(F) def mean_graph(A, F): """ :param A:np.array the adjacency matrix :param
F:np.array the nodes features :return: the mean of the features """ return map_reduce_graph(A, F, np.mean) def max_graph(A, F): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: the maximum of the features """ return map_reduce_graph(A, F, np.max) def min_graph(A, F): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: the minimum of the features """ return map_reduce_graph(A, F, np.min) def std_graph(A, F): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: the standard deviation of the features """ return map_reduce_graph(A, F, np.std) def has_hamiltonian_cycle(A): """ :param A:np.array the adjacency matrix :return:bool whether the graph has a Hamiltonian cycle """ A = A + np.transpose(A) # symmetrise without mutating the caller's matrix A = A > 0 V = A.shape[0] def ham_cycle_loop(pos): if pos == V: if A[path[pos - 1]][path[0]]: return True else: return False for v in range(1, V): if A[path[pos - 1]][v] and not used[v]: path[pos] = v used[v] = True if ham_cycle_loop(pos + 1): return True path[pos] = -1 used[v] = False return False used = [False] * V path = [-1] * V path[0] = 0 return ham_cycle_loop(1) def all_pairs_shortest_paths(A, inf_sub=math.inf): """ :param A:np.array the adjacency matrix :param inf_sub: the placeholder value to use for pairs which are not connected :return:np.array all pairs shortest paths """ A = np.array(A, dtype=float) # float so that math.inf can be stored N = A.shape[0] for i in range(N): for j in range(N): if A[i][j] == 0: A[i][j] = math.inf if i == j: A[i][j] = 0 for k in range(N): for i in range(N): for j in range(N): A[i][j] = min(A[i][j], A[i][k] + A[k][j]) A = np.where(A == math.inf, inf_sub, A) return A def diameter(A): """ :param A:np.array the adjacency matrix :return: the diameter of the graph """ sum = np.sum(A) apsp = all_pairs_shortest_paths(A) apsp = np.where(apsp < sum + 1, apsp, -1) return np.max(apsp) def eccentricity(A): """ :param A:np.array the adjacency matrix :return: the eccentricity of each node """
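`all_pairs_shortest_paths` is a plain Floyd-Warshall pass; `diameter` and `eccentricity` just post-process its output. A quick self-contained check on a 3-node path graph (the routine is restated here so the snippet runs on its own):

```python
import math
import numpy as np

def all_pairs_shortest_paths(A, inf_sub=math.inf):
    # Floyd-Warshall, mirroring the function above
    A = np.array(A, dtype=float)
    N = A.shape[0]
    for i in range(N):
        for j in range(N):
            if A[i][j] == 0:
                A[i][j] = math.inf
            if i == j:
                A[i][j] = 0
    for k in range(N):
        for i in range(N):
            for j in range(N):
                A[i][j] = min(A[i][j], A[i][k] + A[k][j])
    return np.where(A == math.inf, inf_sub, A)

# path graph 0 - 1 - 2: the farthest pair is (0, 2) at distance 2
A = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
apsp = all_pairs_shortest_paths(A)
print(apsp.max())  # 2.0, the diameter of this graph
```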
sum = np.sum(A) apsp = all_pairs_shortest_paths(A) apsp = np.where(apsp < sum + 1, apsp, -1) return np.max(apsp, axis=0) def sssp_predecessor(A, F): """ :param A:np.array the adjacency matrix :param F:np.array the nodes features :return: for each node, the best next step to reach the designated source """ assert (np.sum(F) == 1) assert (np.max(F) == 1) s = np.argmax(F) N = A.shape[0] P = np.zeros(A.shape) V = np.zeros(N) bfs = Queue() bfs.put(s) V[s] = 1 while not bfs.empty(): u = bfs.get() for v in range(N): if A[u][v] > 0 and V[v] == 0: V[v] = 1 P[v][u] = 1 bfs.put(v) return P def max_eigenvalue(A): """ :param A:np.array the adjacency matrix :return: the maximum eigenvalue of A since A is positive symmetric, all the eigenvalues are guaranteed to be real """ [W, _] = np.linalg.eig(A) return W[np.argmax(np.absolute(W))].real def max_eigenvalues(A, k): """ :param A:np.array the adjacency matrix :param k:int the number of eigenvalues to be selected :return: the k greatest (by absolute value) eigenvalues of A """ [W, _] = np.linalg.eig(A) values = W[sorted(range(len(W)), key=lambda x: -np.absolute(W[x]))[:k]] return values.real def max_absolute_eigenvalues(A, k): """ :param A:np.array the adjacency matrix :param k:int the number of eigenvalues to be selected :return: the absolute value of the k greatest (by absolute value) eigenvalues of A """ return np.absolute(max_eigenvalues(A, k)) def max_absolute_eigenvalues_laplacian(A, n): """ :param A:np.array the adjacency matrix :param k:int the number of eigenvalues to be selected :return: the absolute value of the k greatest (by absolute value) eigenvalues of the laplacian of A """ A = graph_laplacian(A) return np.absolute(max_eigenvalues(A, n)) def max_eigenvector(A): """ :param A:np.array the adjacency matrix :return: the maximum (by absolute value) eigenvector of A since A is positive symmetric, all the eigenvectors are guaranteed to be real """ [W, V] = np.linalg.eig(A) return V[:, np.argmax(np.absolute(W))].real def 
spectral_radius(A): """ :param A:np.array the adjacency matrix :return: the spectral radius, i.e. the maximum absolute value of the eigenvalues of A; since A is symmetric, all the eigenvalues are guaranteed to be real """ return np.abs(max_eigenvalue(A)) def page_rank(A, F=None, iter=64): """ :param A:np.array the adjacency matrix :param F:np.array with initial weights. If None, uniform initialization will happen. :param iter: log2 of length of power iteration :return: for each node, its pagerank """ # normalize A rows A = np.array(A, dtype=float) # float copy so that in-place normalization is safe A /= A.sum(axis=1)[:, np.newaxis] # power iteration for _ in range(iter): A = np.matmul(A, A) # generate prior distribution if F is None: F = np.ones(A.shape[-1]) else: F = np.array(F) # normalize prior F /= np.sum(F) # compute limit distribution return np.matmul(F, A) def tsp_length(A, F=None): """ :param A:np.array the adjacency matrix :param F:np.array determining which nodes are to be visited. If None, all of them are. :return: the length of the Traveling Salesman Problem shortest solution """ A = all_pairs_shortest_paths(A) N = A.shape[0] if F is None: F = np.ones(N) targets = np.nonzero(F)[0] T = targets.shape[0] S = (1 << T) dp = np.zeros((S, T)) def popcount(x): b = 0 while x > 0: x &= x - 1 b += 1 return b msks = np.argsort(np.vectorize(popcount)(np.arange(S))) for i in range(T + 1): for j in range(T): if (1 << j) & msks[i] == 0: dp[msks[i]][j] = math.inf for i in range(T + 1, S): msk = msks[i] for u in range(T): if (1 << u) & msk == 0: dp[msk][u] = math.inf continue cost = math.inf for v in range(T): if v == u or (1 << v) & msk == 0: continue cost = min(cost, dp[msk ^ (1 << u)][v] + A[targets[v]][targets[u]]) dp[msk][u] = cost return np.min(dp[S - 1]) def get_nodes_labels(A, F): """ Takes the adjacency matrix and the list of nodes features (and a list of algorithms) and returns a set of labels for each node :param A: adj_matrix, N*N numpy matrix :param F: node_values, numpy array of size N :return: labels: N*K numpy matrix where K is the
number of labels for each node """ labels = [identity(A, F), map_reduce_neighbourhood(A, F, np.mean, consider_itself=True), map_reduce_neighbourhood(A, F, np.max, consider_itself=True), map_reduce_neighbourhood(A, F, np.std, consider_itself=True), first_neighbours(A), second_neighbours(A), eccentricity(A)] return np.swapaxes(np.stack(labels), 0, 1) def get_graph_labels(A, F): """ Takes the adjacency matrix and the list of nodes features (and a list of algorithms) and returns a set of labels for the whole graph :param A: adj_matrix, N*N numpy matrix :param F: node_values, numpy array of size N :return: labels: numpy array of size K where K is the number of labels for the graph """ labels = [diameter(A)] return np.asarray(labels) ================================================ FILE: multitask_benchmark/datasets_generation/graph_generation.py ================================================ import numpy as np import random import networkx as nx import math import matplotlib.pyplot as plt # only required to plot from enum import Enum """ Generates random graphs of different types of a given size. 
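`get_nodes_labels` stacks the K per-node label vectors (each of length N) and swaps the first two axes so that row i holds the K labels of node i. A toy sketch of that layout trick, with hypothetical labels standing in for the graph algorithms:

```python
import numpy as np

# mirrors get_nodes_labels: stack K label vectors of length N, then
# swap axes so each row holds all K labels of one node (shape N x K)
N = 4
F = np.arange(4.0)                   # toy node features [0, 1, 2, 3]
labels = [
    F,                               # identity label
    np.full(N, F.mean()),            # a hypothetical per-node statistic
]
stacked = np.swapaxes(np.stack(labels), 0, 1)
print(stacked.shape)  # (4, 2)
```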
Some of the graphs are created using the NetworkX library, for more info see https://networkx.github.io/documentation/networkx-1.10/reference/generators.html """ class GraphType(Enum): RANDOM = 0 ERDOS_RENYI = 1 BARABASI_ALBERT = 2 GRID = 3 CAVEMAN = 5 TREE = 6 LADDER = 7 LINE = 8 STAR = 9 CATERPILLAR = 10 LOBSTER = 11 # probabilities of each type in case of random type MIXTURE = [(GraphType.ERDOS_RENYI, 0.2), (GraphType.BARABASI_ALBERT, 0.2), (GraphType.GRID, 0.05), (GraphType.CAVEMAN, 0.05), (GraphType.TREE, 0.15), (GraphType.LADDER, 0.05), (GraphType.LINE, 0.05), (GraphType.STAR, 0.05), (GraphType.CATERPILLAR, 0.1), (GraphType.LOBSTER, 0.1)] def erdos_renyi(N, degree, seed): """ Creates an Erdős-Rényi or binomial graph of size N with degree/N probability of edge creation """ return nx.fast_gnp_random_graph(N, degree / N, seed, directed=False) def barabasi_albert(N, degree, seed): """ Creates a random graph according to the Barabási–Albert preferential attachment model, of size N and where each new node is attached with degree edges """ return nx.barabasi_albert_graph(N, degree, seed) def grid(N): """ Creates a m x k 2d grid graph with N = m*k and m and k as close as possible """ m = 1 for i in range(1, int(math.sqrt(N)) + 1): if N % i == 0: m = i return nx.grid_2d_graph(m, N // m) def caveman(N): """ Creates a caveman graph of m cliques of size k, with m and k as close as possible """ m = 1 for i in range(1, int(math.sqrt(N)) + 1): if N % i == 0: m = i return nx.caveman_graph(m, N // m) def tree(N, seed): """ Creates a tree of size N with a power law degree distribution """ return nx.random_powerlaw_tree(N, seed=seed, tries=10000) def ladder(N): """ Creates a ladder graph of N nodes: two rows of N/2 nodes, with each pair connected by a single edge. In case N is odd another node is attached to the first one.
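When `type` is `GraphType.RANDOM`, `generate_graph` below draws a concrete type from `MIXTURE` with a weighted `np.random.choice`. A sketch of that draw, using a hypothetical three-entry mixture for brevity (the real `MIXTURE` has ten entries):

```python
import numpy as np
from enum import Enum

class GraphType(Enum):
    ERDOS_RENYI = 1
    BARABASI_ALBERT = 2
    TREE = 6

# hypothetical reduced mixture for illustration only
MIXTURE = [(GraphType.ERDOS_RENYI, 0.5),
           (GraphType.BARABASI_ALBERT, 0.3),
           (GraphType.TREE, 0.2)]

probs = [pr for (_, pr) in MIXTURE]
rng = np.random.default_rng(0)
# draw an index according to the mixture weights, then look up the type
chosen = MIXTURE[rng.choice(len(MIXTURE), p=probs)][0]
```

Note the weights must sum to 1 or `np.random.choice` raises a `ValueError`.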
""" G = nx.ladder_graph(N // 2) if N % 2 != 0: G.add_node(N - 1) G.add_edge(0, N - 1) return G def line(N): """ Creates a graph composed of N nodes in a line """ return nx.path_graph(N) def star(N): """ Creates a graph composed by one center node connected N-1 outer nodes """ return nx.star_graph(N - 1) def caterpillar(N, seed): """ Creates a random caterpillar graph with a backbone of size b (drawn from U[1, N)), and N − b pendent vertices uniformly connected to the backbone. """ np.random.seed(seed) B = np.random.randint(low=1, high=N) G = nx.empty_graph(N) for i in range(1, B): G.add_edge(i - 1, i) for i in range(B, N): G.add_edge(i, np.random.randint(B)) return G def lobster(N, seed): """ Creates a random Lobster graph with a backbone of size b (drawn from U[1, N)), and p (drawn from U[1, N − b ]) pendent vertices uniformly connected to the backbone, and additional N − b − p pendent vertices uniformly connected to the previous pendent vertices """ np.random.seed(seed) B = np.random.randint(low=1, high=N) F = np.random.randint(low=B + 1, high=N + 1) G = nx.empty_graph(N) for i in range(1, B): G.add_edge(i - 1, i) for i in range(B, F): G.add_edge(i, np.random.randint(B)) for i in range(F, N): G.add_edge(i, np.random.randint(low=B, high=F)) return G def randomize(A): """ Adds some randomness by toggling some edges without changing the expected number of edges of the graph """ BASE_P = 0.9 # e is the number of edges, r the number of missing edges N = A.shape[0] e = np.sum(A) / 2 r = N * (N - 1) / 2 - e # ep chance of an existing edge to remain, rp chance of another edge to appear if e <= r: ep = BASE_P rp = (1 - BASE_P) * e / r else: ep = BASE_P + (1 - BASE_P) * (e - r) / e rp = 1 - BASE_P array = np.random.uniform(size=(N, N), low=0.0, high=0.5) array = array + array.transpose() remaining = np.multiply(np.where(array < ep, 1, 0), A) appearing = np.multiply(np.multiply(np.where(array < rp, 1, 0), 1 - A), 1 - np.eye(N)) ans = np.add(remaining, appearing) # assert 
(np.all(np.multiply(ans, np.eye(N)) == np.zeros((N, N)))) # assert (np.all(ans >= 0)) # assert (np.all(ans <= 1)) # assert (np.all(ans == ans.transpose())) return ans def generate_graph(N, type=GraphType.RANDOM, seed=None, degree=None): """ Generates random graphs of different types of a given size. Note: - graphs are undirected and without weights on edges - node values are sampled independently from U[0,1] :param N: number of nodes :param type: type chosen between the categories specified in GraphType enum :param seed: random seed :param degree: average degree of a node, only used in some graph types :return: adj_matrix: N*N numpy matrix node_values: numpy array of size N type: the GraphType effectively used """ random.seed(seed) np.random.seed(seed) # sample which random type to use if type == GraphType.RANDOM: type = np.random.choice([t for (t, _) in MIXTURE], 1, p=[pr for (_, pr) in MIXTURE])[0] # generate the graph structure depending on the type if type == GraphType.ERDOS_RENYI: if degree is None: degree = random.random() * N G = erdos_renyi(N, degree, seed) elif type == GraphType.BARABASI_ALBERT: if degree is None: degree = int(random.random() * (N - 1)) + 1 G = barabasi_albert(N, degree, seed) elif type == GraphType.GRID: G = grid(N) elif type == GraphType.CAVEMAN: G = caveman(N) elif type == GraphType.TREE: G = tree(N, seed) elif type == GraphType.LADDER: G = ladder(N) elif type == GraphType.LINE: G = line(N) elif type == GraphType.STAR: G = star(N) elif type == GraphType.CATERPILLAR: G = caterpillar(N, seed) elif type == GraphType.LOBSTER: G = lobster(N, seed) else: raise ValueError("Graph type not defined") # generate adjacency matrix and nodes values nodes = list(G) random.shuffle(nodes) adj_matrix = nx.to_numpy_array(G, nodes) node_values = np.random.uniform(low=0, high=1, size=N) # randomization adj_matrix = randomize(adj_matrix) # draw the graph created # nx.draw(G, pos=nx.spring_layout(G)) # plt.draw() return adj_matrix, node_values, type if __name__ == '__main__': for i in range(100):
adj_matrix, node_values, _ = generate_graph(10, GraphType.RANDOM, seed=i) print(adj_matrix) ================================================ FILE: multitask_benchmark/datasets_generation/multitask_dataset.py ================================================ import argparse import os import pickle import numpy as np import torch from inspect import signature from tqdm import tqdm from . import graph_algorithms from .graph_generation import GraphType, generate_graph class DatasetMultitask: def __init__(self, n_graphs, N, seed, graph_type, get_nodes_labels, get_graph_labels, print_every, sssp, filename): self.adj = {} self.features = {} self.nodes_labels = {} self.graph_labels = {} def to_categorical(x, N): v = np.zeros(N) v[x] = 1 return v for dset in N.keys(): if dset not in n_graphs: n_graphs[dset] = n_graphs['default'] total_n_graphs = sum(n_graphs[dset]) set_adj = [[] for _ in n_graphs[dset]] set_features = [[] for _ in n_graphs[dset]] set_nodes_labels = [[] for _ in n_graphs[dset]] set_graph_labels = [[] for _ in n_graphs[dset]] t = tqdm(total=np.sum(n_graphs[dset]), desc=dset, leave=True, unit=' graphs') for batch, batch_size in enumerate(n_graphs[dset]): for i in range(batch_size): # generate a random graph of type graph_type and size N seed += 1 adj, features, type = generate_graph(N[dset][batch], graph_type, seed=seed) while np.min(np.max(adj, 0)) == 0.0: # regenerate graphs that have singleton nodes seed += 1 adj, features, _ = generate_graph(N[dset][batch], type, seed=seed) t.update(1) # make sure there are no self-connections assert np.all( np.multiply(adj, np.eye(N[dset][batch])) == np.zeros((N[dset][batch], N[dset][batch]))) if sssp: # define the source node source_node = np.random.randint(0, N[dset][batch]) # compute the labels with graph_algorithms; if sssp add the sssp node_labels = get_nodes_labels(adj, features, graph_algorithms.all_pairs_shortest_paths(adj, 0)[source_node] if sssp else None) graph_labels = get_graph_labels(adj, features) if sssp: # add the 1-hot
feature determining the starting node features = np.stack([to_categorical(source_node, N[dset][batch]), features], axis=1) set_adj[batch].append(adj) set_features[batch].append(features) set_nodes_labels[batch].append(node_labels) set_graph_labels[batch].append(graph_labels) t.close() self.adj[dset] = [torch.from_numpy(np.asarray(adjs)).float() for adjs in set_adj] self.features[dset] = [torch.from_numpy(np.asarray(fs)).float() for fs in set_features] self.nodes_labels[dset] = [torch.from_numpy(np.asarray(nls)).float() for nls in set_nodes_labels] self.graph_labels[dset] = [torch.from_numpy(np.asarray(gls)).float() for gls in set_graph_labels] self.save_as_pickle(filename) def save_as_pickle(self, filename): """" Saves the data into a pickle file at filename """ directory = os.path.dirname(filename) if not os.path.exists(directory): os.makedirs(directory) with open(filename, 'wb') as f: torch.save((self.adj, self.features, self.nodes_labels, self.graph_labels), f) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--out', type=str, default='./multitask_benchmark/data/multitask_dataset.pkl', help='Data path.') parser.add_argument('--seed', type=int, default=1234, help='Random seed.') parser.add_argument('--graph_type', type=str, default='RANDOM', help='Type of graphs in train set') parser.add_argument('--nodes_labels', nargs='+', default=["eccentricity", "graph_laplacian_features", "sssp"]) parser.add_argument('--graph_labels', nargs='+', default=["is_connected", "diameter", "spectral_radius"]) parser.add_argument('--extrapolation', action='store_true', default=False, help='Generated various test sets of dimensions larger than train and validation.') parser.add_argument('--print_every', type=int, default=20, help='') args = parser.parse_args() if 'sssp' in args.nodes_labels: sssp = True args.nodes_labels.remove('sssp') else: sssp = False # gets the functions of graph_algorithms from the specified datasets nodes_labels_algs = 
list(map(lambda s: getattr(graph_algorithms, s), args.nodes_labels)) graph_labels_algs = list(map(lambda s: getattr(graph_algorithms, s), args.graph_labels)) def get_nodes_labels(A, F, initial=None): labels = [] if initial is None else [initial] for f in nodes_labels_algs: params = signature(f).parameters labels.append(f(A, F) if 'F' in params else f(A)) return np.swapaxes(np.stack(labels), 0, 1) def get_graph_labels(A, F): labels = [] for f in graph_labels_algs: params = signature(f).parameters labels.append(f(A, F) if 'F' in params else f(A)) return np.asarray(labels).flatten() data = DatasetMultitask(n_graphs={'train': [512] * 10, 'val': [128] * 5, 'default': [256] * 5}, N={**{'train': range(15, 25), 'val': range(15, 25)}, **( {'test-(20,25)': range(20, 25), 'test-(25,30)': range(25, 30), 'test-(30,35)': range(30, 35), 'test-(35,40)': range(35, 40), 'test-(40,45)': range(40, 45), 'test-(45,50)': range(45, 50), 'test-(60,65)': range(60, 65), 'test-(75,80)': range(75, 80), 'test-(95,100)': range(95, 100)} if args.extrapolation else {'test': range(15, 25)})}, seed=args.seed, graph_type=getattr(GraphType, args.graph_type), get_nodes_labels=get_nodes_labels, get_graph_labels=get_graph_labels, print_every=args.print_every, sssp=sssp, filename=args.out) data.save_as_pickle(args.out) ================================================ FILE: multitask_benchmark/requirements.txt ================================================ numpy networkx matplotlib torch ================================================ FILE: multitask_benchmark/train/gat.py ================================================ from __future__ import division from __future__ import print_function from models.pytorch.gat.layer import GATLayer from multitask_benchmark.util.train import execute_train, build_arg_parser # Training settings parser = build_arg_parser() parser.add_argument('--nheads', type=int, default=4, help='Number of attentions heads.') parser.add_argument('--alpha', type=float, default=0.2, 
help='Alpha for the leaky_relu.') args = parser.parse_args() execute_train(gnn_args=dict(nfeat=None, nhid=args.hidden, nodes_out=None, graph_out=None, dropout=args.dropout, device=None, first_conv_descr=dict(layer_type=GATLayer, args=dict( nheads=args.nheads, alpha=args.alpha )), middle_conv_descr=dict(layer_type=GATLayer, args=dict( nheads=args.nheads, alpha=args.alpha )), fc_layers=args.fc_layers, conv_layers=args.conv_layers, skip=args.skip, gru=args.gru, fixed=args.fixed, variable=args.variable), args=args) ================================================ FILE: multitask_benchmark/train/gcn.py ================================================ from __future__ import division from __future__ import print_function from models.pytorch.gcn.layer import GCNLayer from multitask_benchmark.util.train import execute_train, build_arg_parser # Training settings parser = build_arg_parser() args = parser.parse_args() execute_train(gnn_args=dict(nfeat=None, nhid=args.hidden, nodes_out=None, graph_out=None, dropout=args.dropout, device=None, first_conv_descr=dict(layer_type=GCNLayer, args=dict()), middle_conv_descr=dict(layer_type=GCNLayer, args=dict()), fc_layers=args.fc_layers, conv_layers=args.conv_layers, skip=args.skip, gru=args.gru, fixed=args.fixed, variable=args.variable), args=args) ================================================ FILE: multitask_benchmark/train/gin.py ================================================ from __future__ import division from __future__ import print_function from models.pytorch.gin.layer import GINLayer from multitask_benchmark.util.train import execute_train, build_arg_parser # Training settings parser = build_arg_parser() parser.add_argument('--gin_fc_layers', type=int, default=2, help='Number of fully connected layers after the aggregation.') args = parser.parse_args() execute_train(gnn_args=dict(nfeat=None, nhid=args.hidden, nodes_out=None, graph_out=None, dropout=args.dropout, device=None, first_conv_descr=dict(layer_type=GINLayer, 
args=dict(fc_layers=args.gin_fc_layers)), middle_conv_descr=dict(layer_type=GINLayer, args=dict(fc_layers=args.gin_fc_layers)), fc_layers=args.fc_layers, conv_layers=args.conv_layers, skip=args.skip, gru=args.gru, fixed=args.fixed, variable=args.variable), args=args) ================================================ FILE: multitask_benchmark/train/mpnn.py ================================================ from __future__ import division from __future__ import print_function from models.pytorch.pna.layer import PNALayer from multitask_benchmark.util.train import execute_train, build_arg_parser # Training settings parser = build_arg_parser() parser.add_argument('--self_loop', action='store_true', default=False, help='Whether to add self loops in aggregators') parser.add_argument('--towers', type=int, default=4, help='Number of towers in MPNN layers') parser.add_argument('--aggregation', type=str, default='sum', help='Type of aggregation') parser.add_argument('--pretrans_layers', type=int, default=1, help='Number of MLP layers before aggregation') parser.add_argument('--posttrans_layers', type=int, default=1, help='Number of MLP layers after aggregation') args = parser.parse_args() # The MPNNs can be considered a particular case of PNA networks with a single aggregator and no scalers (identity) execute_train(gnn_args=dict(nfeat=None, nhid=args.hidden, nodes_out=None, graph_out=None, dropout=args.dropout, device=None, first_conv_descr=dict(layer_type=PNALayer, args=dict( aggregators=[args.aggregation], scalers=['identity'], avg_d=None, towers=args.towers, self_loop=args.self_loop, divide_input=False, pretrans_layers=args.pretrans_layers, posttrans_layers=args.posttrans_layers )), middle_conv_descr=dict(layer_type=PNALayer, args=dict( aggregators=[args.aggregation], scalers=['identity'], avg_d=None, towers=args.towers, self_loop=args.self_loop, divide_input=True, pretrans_layers=args.pretrans_layers, posttrans_layers=args.posttrans_layers )), fc_layers=args.fc_layers, 
conv_layers=args.conv_layers, skip=args.skip, gru=args.gru, fixed=args.fixed, variable=args.variable), args=args) ================================================ FILE: multitask_benchmark/train/pna.py ================================================ from __future__ import division from __future__ import print_function from models.pytorch.pna.layer import PNALayer from multitask_benchmark.util.train import execute_train, build_arg_parser # Training settings parser = build_arg_parser() parser.add_argument('--self_loop', action='store_true', default=False, help='Whether to add self loops in aggregators') parser.add_argument('--aggregators', type=str, default='mean max min std', help='Aggregators to use') parser.add_argument('--scalers', type=str, default='identity amplification attenuation', help='Scalers to use') parser.add_argument('--towers', type=int, default=4, help='Number of towers in PNA layers') parser.add_argument('--pretrans_layers', type=int, default=1, help='Number of MLP layers before aggregation') parser.add_argument('--posttrans_layers', type=int, default=1, help='Number of MLP layers after aggregation') args = parser.parse_args() execute_train(gnn_args=dict(nfeat=None, nhid=args.hidden, nodes_out=None, graph_out=None, dropout=args.dropout, device=None, first_conv_descr=dict(layer_type=PNALayer, args=dict( aggregators=args.aggregators.split(), scalers=args.scalers.split(), avg_d=None, towers=args.towers, self_loop=args.self_loop, divide_input=False, pretrans_layers=args.pretrans_layers, posttrans_layers=args.posttrans_layers )), middle_conv_descr=dict(layer_type=PNALayer, args=dict( aggregators=args.aggregators.split(), scalers=args.scalers.split(), avg_d=None, towers=args.towers, self_loop=args.self_loop, divide_input=True, pretrans_layers=args.pretrans_layers, posttrans_layers=args.posttrans_layers )), fc_layers=args.fc_layers, conv_layers=args.conv_layers, skip=args.skip, gru=args.gru, fixed=args.fixed, variable=args.variable), args=args) 
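For reference, the degree scalers named in the `--scalers` default above follow the definition S(d, α) = (log(d+1)/δ)^α from the PNA paper, where δ is the average log-degree over the training set (`avg_d['log']` in this codebase). A small sketch of how the three scalers act on an aggregated message (the function name `scale` is illustrative, not from the repo):

```python
import math

def scale(msg, d, delta, alpha):
    # identity: alpha = 0, amplification: alpha = 1, attenuation: alpha = -1
    return msg * (math.log(d + 1) / delta) ** alpha

delta = math.log(3 + 1)   # hypothetical training-set average log-degree
msg, d = 2.0, 3           # aggregated message for a node of degree 3

identity_out = scale(msg, d, delta, 0.0)
amplified = scale(msg, d, delta, 1.0)
attenuated = scale(msg, d, delta, -1.0)
# at the average degree all three scalers leave the message unchanged
```

Above the average degree, amplification boosts the message and attenuation damps it, which is what lets PNA remain sensitive to degree while staying normalised.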
================================================ FILE: multitask_benchmark/util/train.py ================================================ from __future__ import division from __future__ import print_function import argparse import os import sys import time from types import SimpleNamespace import math import numpy as np import torch import torch.optim as optim from tqdm import tqdm from models.pytorch.gnn_framework import GNN from multitask_benchmark.util.util import load_dataset, total_loss, total_loss_multiple_batches, \ specific_loss_multiple_batches def build_arg_parser(): """ :return: argparse.ArgumentParser() filled with the standard arguments for a training session. Might need to be enhanced for some train_scripts. """ parser = argparse.ArgumentParser() parser.add_argument('--data', type=str, default='../../data/multitask_dataset.pkl', help='Data path.') parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.') parser.add_argument('--only_nodes', action='store_true', default=False, help='Evaluate only nodes labels.') parser.add_argument('--only_graph', action='store_true', default=False, help='Evaluate only graph labels.') parser.add_argument('--seed', type=int, default=42, help='Random seed.') parser.add_argument('--epochs', type=int, default=10000, help='Number of epochs to train.') parser.add_argument('--lr', type=float, default=0.003, help='Initial learning rate.') parser.add_argument('--weight_decay', type=float, default=1e-6, help='Weight decay (L2 loss on parameters).') parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.') parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate (1 - keep probability).') parser.add_argument('--patience', type=int, default=1000, help='Patience') parser.add_argument('--conv_layers', type=int, default=None, help='Graph convolutions') parser.add_argument('--variable_conv_layers', type=str, default='N', help='Graph convolutions 
function name') parser.add_argument('--fc_layers', type=int, default=3, help='Fully connected layers in readout') parser.add_argument('--loss', type=str, default='mse', help='Loss function to use.') parser.add_argument('--print_every', type=int, default=50, help='Print training results every given number of epochs.') parser.add_argument('--final_activation', type=str, default='LeakyReLu', help='Final activation in both the FC layers for nodes and the S2S for graphs.') parser.add_argument('--skip', action='store_true', default=False, help='Whether to use the model with skip connections.') parser.add_argument('--gru', action='store_true', default=False, help='Whether to use a GRU in the update function of the layers.') parser.add_argument('--fixed', action='store_true', default=False, help='Whether to use the model with fixed middle convolutions.') parser.add_argument('--variable', action='store_true', default=False, help='Whether to have a variable number of convolutional layers.') return parser # map from names (as passed as parameters) to functions determining the number of convolutional layers at runtime VARIABLE_LAYERS_FUNCTIONS = { 'N': lambda adj: adj.shape[1], 'N/2': lambda adj: adj.shape[1] // 2, '4log2N': lambda adj: int(4 * math.log2(adj.shape[1])), '2log2N': lambda adj: int(2 * math.log2(adj.shape[1])), '3sqrtN': lambda adj: int(3 * math.sqrt(adj.shape[1])) } def execute_train(gnn_args, args): """ :param gnn_args: the description of the model to be trained (expressed as arguments for GNN.__init__) :param args: the parameters of the training session """ args.cuda = not args.no_cuda and torch.cuda.is_available() np.random.seed(args.seed) torch.manual_seed(args.seed) if args.cuda: torch.cuda.manual_seed(args.seed) device = 'cuda' if args.cuda else 'cpu' print('Using device:', device) # load data adj, features, node_labels, graph_labels = load_dataset(args.data, args.loss, args.only_nodes, args.only_graph, print_baseline=True) # model and optimizer gnn_args = SimpleNamespace(**gnn_args) # compute
avg_d on the training set if 'avg_d' in gnn_args.first_conv_descr['args'] or 'avg_d' in gnn_args.middle_conv_descr['args']: dlist = [torch.sum(A, dim=-1) for A in adj['train']] avg_d = dict(lin=sum([torch.mean(D) for D in dlist]) / len(dlist), exp=sum([torch.mean(torch.exp(torch.div(1, D)) - 1) for D in dlist]) / len(dlist), log=sum([torch.mean(torch.log(D + 1)) for D in dlist]) / len(dlist)) if 'avg_d' in gnn_args.first_conv_descr['args']: gnn_args.first_conv_descr['args']['avg_d'] = avg_d if 'avg_d' in gnn_args.middle_conv_descr['args']: gnn_args.middle_conv_descr['args']['avg_d'] = avg_d gnn_args.device = device gnn_args.nfeat = features['train'][0].shape[2] gnn_args.nodes_out = node_labels['train'][0].shape[-1] gnn_args.graph_out = graph_labels['train'][0].shape[-1] if gnn_args.variable: assert gnn_args.conv_layers is None, "If model is variable, you shouldn't specify conv_layers (maybe you " \ "meant variable_conv_layers?) " else: assert gnn_args.conv_layers is not None, "If the model is not variable, you should specify conv_layers" gnn_args.conv_layers = VARIABLE_LAYERS_FUNCTIONS[ args.variable_conv_layers] if gnn_args.variable else args.conv_layers model = GNN(**vars(gnn_args)) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print("Total params", pytorch_total_params) def move_cuda(dset): assert args.cuda, "Cannot move dataset on CUDA, running on cpu" if features[dset][0].is_cuda: # already on CUDA return features[dset] = [x.cuda() for x in features[dset]] adj[dset] = [x.cuda() for x in adj[dset]] node_labels[dset] = [x.cuda() for x in node_labels[dset]] graph_labels[dset] = [x.cuda() for x in graph_labels[dset]] if args.cuda: model.cuda() # move train, val to CUDA (delay moving test until needed) move_cuda('train') move_cuda('val') def train(epoch): """ Execute a single epoch of the training loop :param epoch:int the number of the 
epoch being performed (0-indexed) """ t = time.time() # train step model.train() for batch in range(len(adj['train'])): optimizer.zero_grad() output = model(features['train'][batch], adj['train'][batch]) loss_train = total_loss(output, (node_labels['train'][batch], graph_labels['train'][batch]), loss=args.loss, only_nodes=args.only_nodes, only_graph=args.only_graph) loss_train.backward() optimizer.step() # validation epoch model.eval() output_zip = [model(features['val'][batch], adj['val'][batch]) for batch in range(len(adj['val']))] output = ([x[0] for x in output_zip], [x[1] for x in output_zip]) loss_val = total_loss_multiple_batches(output, (node_labels['val'], graph_labels['val']), loss=args.loss, only_nodes=args.only_nodes, only_graph=args.only_graph) return loss_train.data.item(), loss_val def compute_test(): """ Evaluate the current model on all the sets of the dataset, printing results. This procedure is destructive on datasets. """ model.eval() sets = list(features.keys()) for dset in sets: # move data on CUDA if not already on it if args.cuda: move_cuda(dset) output_zip = [model(features[dset][batch], adj[dset][batch]) for batch in range(len(adj[dset]))] output = ([x[0] for x in output_zip], [x[1] for x in output_zip]) loss_test = total_loss_multiple_batches(output, (node_labels[dset], graph_labels[dset]), loss=args.loss, only_nodes=args.only_nodes, only_graph=args.only_graph) print("Test set results ", dset, ": loss= {:.4f}".format(loss_test)) print(dset, ": ", specific_loss_multiple_batches(output, (node_labels[dset], graph_labels[dset]), loss=args.loss, only_nodes=args.only_nodes, only_graph=args.only_graph)) # free unnecessary data del output_zip del output del loss_test del features[dset] del adj[dset] del node_labels[dset] del graph_labels[dset] torch.cuda.empty_cache() sys.stdout.flush() # Train model t_total = time.time() loss_values = [] bad_counter = 0 best = args.epochs + 1 best_epoch = -1 sys.stdout.flush() with tqdm(range(args.epochs), 
leave=True, unit='epoch') as t: for epoch in t: loss_train, loss_val = train(epoch) loss_values.append(loss_val) t.set_description('loss.train: {:.4f}, loss.val: {:.4f}'.format(loss_train, loss_val)) if loss_values[-1] < best: # save current model torch.save(model.state_dict(), '{}.pkl'.format(epoch)) # remove previous model if best_epoch >= 0: os.remove('{}.pkl'.format(best_epoch)) # update training variables best = loss_values[-1] best_epoch = epoch bad_counter = 0 else: bad_counter += 1 if bad_counter == args.patience: print('Early stop at epoch {} (no improvement in last {} epochs)'.format(epoch + 1, bad_counter)) break print("Optimization Finished!") print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) # Restore best model print('Loading {}th epoch'.format(best_epoch + 1)) model.load_state_dict(torch.load('{}.pkl'.format(best_epoch))) # Testing with torch.no_grad(): compute_test() ================================================ FILE: multitask_benchmark/util/util.py ================================================ from __future__ import division from __future__ import print_function import torch import torch.nn.functional as F def load_dataset(data_path, loss, only_nodes, only_graph, print_baseline=True): with open(data_path, 'rb') as f: (adj, features, node_labels, graph_labels) = torch.load(f) # normalize labels max_node_labels = torch.cat([nls.max(0)[0].max(0)[0].unsqueeze(0) for nls in node_labels['train']]).max(0)[0] max_graph_labels = torch.cat([gls.max(0)[0].unsqueeze(0) for gls in graph_labels['train']]).max(0)[0] for dset in node_labels.keys(): node_labels[dset] = [nls / max_node_labels for nls in node_labels[dset]] graph_labels[dset] = [gls / max_graph_labels for gls in graph_labels[dset]] if print_baseline: # calculate baseline mean_node_labels = torch.cat([nls.mean(0).mean(0).unsqueeze(0) for nls in node_labels['train']]).mean(0) mean_graph_labels = torch.cat([gls.mean(0).unsqueeze(0) for gls in graph_labels['train']]).mean(0) for 
dset in node_labels.keys(): if dset not in ['train', 'val']: baseline_nodes = [mean_node_labels.repeat(list(nls.shape[0:-1]) + [1]) for nls in node_labels[dset]] baseline_graph = [mean_graph_labels.repeat([gls.shape[0], 1]) for gls in graph_labels[dset]] print("Baseline loss ", dset, specific_loss_multiple_batches((baseline_nodes, baseline_graph), (node_labels[dset], graph_labels[dset]), loss=loss, only_nodes=only_nodes, only_graph=only_graph)) return adj, features, node_labels, graph_labels def get_loss(loss, output, target): if loss == "mse": return F.mse_loss(output, target) elif loss == "cross_entropy": if len(output.shape) > 2: (B, N, _) = output.shape output = output.reshape((B * N, -1)) target = target.reshape((B * N, -1)) _, target = target.max(dim=1) return F.cross_entropy(output, target) else: raise ValueError("loss function not supported: " + loss) def total_loss(output, target, loss='mse', only_nodes=False, only_graph=False): """ returns the average of the average losses of each task """ assert not (only_nodes and only_graph) if only_nodes: nodes_loss = get_loss(loss, output[0], target[0]) return nodes_loss elif only_graph: graph_loss = get_loss(loss, output[1], target[1]) return graph_loss nodes_loss = get_loss(loss, output[0], target[0]) graph_loss = get_loss(loss, output[1], target[1]) weighted_average = (nodes_loss * output[0].shape[-1] + graph_loss * output[1].shape[-1]) / ( output[0].shape[-1] + output[1].shape[-1]) return weighted_average def total_loss_multiple_batches(output, target, loss='mse', only_nodes=False, only_graph=False): """ returns the average of the average losses of each task over all batches, batches are weighted equally regardless of their cardinality or graph size """ n_batches = len(output[0]) return sum([total_loss((output[0][batch], output[1][batch]), (target[0][batch], target[1][batch]), loss, only_nodes, only_graph).data.item() for batch in range(n_batches)]) / n_batches def specific_loss(output, target, loss='mse', only_nodes=False,
only_graph=False): """ returns the average loss for each task """ assert not (only_nodes and only_graph) n_nodes_labels = output[0].shape[-1] if not only_graph else 0 n_graph_labels = output[1].shape[-1] if not only_nodes else 0 if only_nodes: nodes_loss = [get_loss(loss, output[0][:, :, k], target[0][:, :, k]).item() for k in range(n_nodes_labels)] return nodes_loss elif only_graph: graph_loss = [get_loss(loss, output[1][:, k], target[1][:, k]).item() for k in range(n_graph_labels)] return graph_loss nodes_loss = [get_loss(loss, output[0][:, :, k], target[0][:, :, k]).item() for k in range(n_nodes_labels)] graph_loss = [get_loss(loss, output[1][:, k], target[1][:, k]).item() for k in range(n_graph_labels)] return nodes_loss + graph_loss def specific_loss_multiple_batches(output, target, loss='mse', only_nodes=False, only_graph=False): """ returns the average loss over all batches for each task, batches are weighted equally regardless of their cardinality or graph size """ assert not (only_nodes and only_graph) n_batches = len(output[0]) classes = (output[0][0].shape[-1] if not only_graph else 0) + (output[1][0].shape[-1] if not only_nodes else 0) sum_losses = [0] * classes for batch in range(n_batches): spec_loss = specific_loss((output[0][batch], output[1][batch]), (target[0][batch], target[1][batch]), loss, only_nodes, only_graph) for par in range(classes): sum_losses[par] += spec_loss[par] return [sum_loss / n_batches for sum_loss in sum_losses] ================================================ FILE: realworld_benchmark/README.md ================================================ # Real-world benchmarks ## Overview We provide the scripts for downloading and running the real-world benchmarks we used.
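The multitask loss helpers above (`total_loss` and friends) weight the node-level and graph-level terms by the number of target labels each head predicts, so the result is an average of per-task averages. A minimal self-contained sketch of that weighting on toy tensors (the shapes are illustrative, not taken from the benchmark):

```python
import torch
import torch.nn.functional as F

# Toy predictions/targets: batch of 4 graphs, 6 nodes each,
# 3 node-level targets and 2 graph-level targets.
node_out, node_tgt = torch.rand(4, 6, 3), torch.rand(4, 6, 3)
graph_out, graph_tgt = torch.rand(4, 2), torch.rand(4, 2)

nodes_loss = F.mse_loss(node_out, node_tgt)
graph_loss = F.mse_loss(graph_out, graph_tgt)

# Weighted average mirroring total_loss: each head contributes in
# proportion to how many labels it predicts (3 node vs 2 graph here).
total = (nodes_loss * 3 + graph_loss * 2) / (3 + 2)
```

Because the weights sum to one, `total` always lies between the two per-task losses.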
Many scripts in this directory were taken directly from or inspired by "Benchmarking GNNs" by Dwivedi _et al._; refer to their [code](https://github.com/graphdeeplearning/benchmarking-gnns) and [paper](https://arxiv.org/abs/2003.00982) for more details on their work. The graph classification benchmark MolHIV comes from the [Open Graph Benchmark](https://ogb.stanford.edu/). - `configs` contains .json configuration files for the various datasets; - `data` contains scripts to download the datasets; - `nets` contains the architectures that were used with the PNA in the benchmarks; - `train` contains the training scripts. These benchmarks use the DGL version of PNA (`../models/dgl`), with the MolHIV model using the *simple* layer architecture. Below you can find the instructions on how to download the datasets and run the models. You can run these scripts directly in this [notebook](https://colab.research.google.com/drive/1RnV4MBjCl98eubAGpEF-eXdAW5mTP3h3?usp=sharing). ## Test run ### Benchmark Setup [Follow these instructions](./docs/setup.md) to install the benchmark and set up the environment. ### Run model training ``` # at the root of the repo cd realworld_benchmark python { main_molecules.py | main_superpixels.py } [--param=value ...] --dataset { ZINC | MNIST | CIFAR10 } --gpu_id gpu_id --config config_file ``` ## Tuned hyperparameters You can find below the hyperparameters we used for our experiments. In general, the depth of the architectures was not changed, while the width was adjusted to keep the total number of parameters of the model between 100k and 110k, as done in "Benchmarking GNNs", to ensure a fair comparison of the architectures. Refer to our [paper](https://arxiv.org/abs/2004.05718) for an interpretation of the results.
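The `--aggregators` and `--scalers` flags in the commands below combine several neighbourhood aggregations (mean, max, min, std) with degree-based scalers (identity, amplification, attenuation). A rough self-contained sketch of that combination for a single node, following the log-based scalers of the PNA paper; `avg_d_log` (the training-set average of log(d + 1)) and `pna_aggregate` are illustrative names, not functions from this repository:

```python
import torch

def pna_aggregate(neigh, avg_d_log, alpha=1.0):
    """Combine mean/max/min/std aggregators with identity,
    amplification and attenuation degree scalers for one node.

    neigh:     (d, F) tensor of neighbour features; d = node degree.
    avg_d_log: average of log(d + 1) over the training graphs.
    """
    d = neigh.shape[0]
    aggs = [neigh.mean(dim=0),
            neigh.max(dim=0).values,
            neigh.min(dim=0).values,
            neigh.std(dim=0, unbiased=False)]
    agg = torch.cat(aggs)  # (4F,)
    s = torch.log(torch.tensor(d + 1.0)) / avg_d_log
    scaled = [agg,                 # identity
              agg * s ** alpha,    # amplification: boosts high-degree nodes
              agg * s ** -alpha]   # attenuation: damps high-degree nodes
    return torch.cat(scaled)       # (12F,)

out = pna_aggregate(torch.rand(5, 8), avg_d_log=1.5)
```

With 4 aggregators and 3 scalers, the output is 12 times the feature width, which the layer's post-transformation then projects back down; the MPNN baselines below use a single aggregator with `--scalers="identity"`.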
``` For OGB leaderboard (hyperparameters taken from the DGN model - 300k parameters): python -m main_HIV --weight_decay=3e-6 --L=4 --hidden_dim=80 --out_dim=80 --residual=True --readout=mean --in_feat_dropout=0.0 --dropout=0.3 --batch_norm=True --aggregators="mean max min std" --scalers="identity amplification attenuation" --dataset HIV --gpu_id 0 --config "configs/molecules_graph_classification_PNA_HIV.json" --epochs=200 --init_lr=0.01 --lr_reduce_factor=0.5 --lr_schedule_patience=20 --min_lr=0.0001 For the leaderboard (2nd version of the datasets - 400/500k parameters) # ZINC PNA: python main_molecules.py --weight_decay=3e-6 --L=16 --hidden_dim=70 --out_dim=70 --residual=True --edge_feat=True --edge_dim=40 --readout=sum --in_feat_dropout=0.0 --dropout=0.0 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity amplification attenuation" --towers=5 --pretrans_layers=1 --posttrans_layers=1 --divide_input_first=True --divide_input_last=True --dataset ZINC --gpu_id 0 --config "configs/molecules_graph_regression_pna_ZINC.json" --lr_schedule_patience=20 MPNN (sum/max): python main_molecules.py --weight_decay=3e-6 --L=16 --hidden_dim=110 --out_dim=110 --residual=True --edge_feat=True --edge_dim=40 --readout=sum --in_feat_dropout=0.0 --dropout=0.0 --graph_norm=True --batch_norm=True --aggregators="sum"/"max" --scalers="identity" --towers=5 --pretrans_layers=1 --posttrans_layers=1 --divide_input_first=True --divide_input_last=True --dataset ZINC --gpu_id 0 --config "configs/molecules_graph_regression_pna_ZINC.json" --lr_schedule_patience=20 For the paper (1st version of the datasets - 100k parameters) --- PNA --- # ZINC python main_molecules.py --weight_decay=3e-6 --L=4 --hidden_dim=75 --out_dim=70 --residual=True --edge_feat=False --readout=sum --in_feat_dropout=0.0 --dropout=0.0 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity amplification attenuation" --towers=5 --divide_input_first=False 
--divide_input_last=True --dataset ZINC --gpu_id 0 --config "configs/molecules_graph_regression_pna_ZINC.json" --lr_schedule_patience=5 python main_molecules.py --weight_decay=3e-6 --L=4 --hidden_dim=70 --out_dim=60 --residual=True --edge_feat=True --edge_dim=50 --readout=sum --in_feat_dropout=0.0 --dropout=0.0 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity amplification attenuation" --towers=5 --pretrans_layers=1 --posttrans_layers=1 --divide_input_first=True --divide_input_last=True --dataset ZINC --gpu_id 0 --config "configs/molecules_graph_regression_pna_ZINC.json" --lr_schedule_patience=20 # CIFAR10 python main_superpixels.py --weight_decay=3e-6 --L=4 --hidden_dim=75 --out_dim=70 --residual=True --edge_feat=False --readout=sum --in_feat_dropout=0.0 --dropout=0.1 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity amplification attenuation" --towers=5 --divide_input_first=True --divide_input_last=True --dataset CIFAR10 --gpu_id 0 --config "configs/superpixels_graph_classification_pna_CIFAR10.json" --lr_schedule_patience=5 python main_superpixels.py --weight_decay=3e-6 --L=4 --hidden_dim=75 --out_dim=70 --residual=True --edge_feat=True --edge_dim=50 --readout=sum --in_feat_dropout=0.0 --dropout=0.3 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity amplification attenuation" --towers=5 --divide_input_first=True --divide_input_last=True --dataset CIFAR10 --gpu_id 0 --config "configs/superpixels_graph_classification_pna_CIFAR10.json" --lr_schedule_patience=5 # MNIST python main_superpixels.py --weight_decay=3e-6 --L=4 --hidden_dim=75 --out_dim=70 --residual=True --edge_feat=False --readout=sum --in_feat_dropout=0.0 --dropout=0.1 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity amplification attenuation" --towers=5 --divide_input_first=True --divide_input_last=True --dataset MNIST --gpu_id 0 --config 
"configs/superpixels_graph_classification_pna_MNIST.json" --lr_schedule_patience=5 python main_superpixels.py --weight_decay=3e-6 --L=4 --hidden_dim=75 --out_dim=70 --residual=True --edge_feat=True --edge_dim=50 --readout=sum --in_feat_dropout=0.0 --dropout=0.3 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity amplification attenuation" --towers=5 --divide_input_first=True --divide_input_last=True --dataset MNIST --gpu_id 0 --config "configs/superpixels_graph_classification_pna_MNIST.json" --lr_schedule_patience=5 --- PNA (no scalers) --- # ZINC python main_molecules.py --weight_decay=3e-6 --L=4 --hidden_dim=95 --out_dim=90 --residual=True --edge_feat=False --readout=sum --in_feat_dropout=0.0 --dropout=0.0 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity" --towers=5 --divide_input_first=True --divide_input_last=True --dataset ZINC --gpu_id 0 --config "configs/molecules_graph_regression_pna_ZINC.json" --lr_schedule_patience=5 python main_molecules.py --weight_decay=3e-6 --L=4 --hidden_dim=90 --out_dim=80 --residual=True --edge_feat=True --edge_dim=50 --readout=sum --in_feat_dropout=0.0 --dropout=0.0 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity" --towers=5 --pretrans_layers=1 --posttrans_layers=1 --divide_input_first=True --divide_input_last=True --dataset ZINC --gpu_id 0 --config "configs/molecules_graph_regression_pna_ZINC.json" --lr_schedule_patience=20 # CIFAR10 python main_superpixels.py --weight_decay=3e-6 --L=4 --hidden_dim=95 --out_dim=90 --residual=True --edge_feat=False --readout=sum --in_feat_dropout=0.0 --dropout=0.1 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity" --towers=5 --divide_input_first=True --divide_input_last=True --dataset CIFAR10 --gpu_id 0 --config "configs/superpixels_graph_classification_pna_CIFAR10.json" --lr_schedule_patience=5 python main_superpixels.py --weight_decay=3e-6 
--L=4 --hidden_dim=95 --out_dim=90 --residual=True --edge_feat=True --edge_dim=50 --readout=sum --in_feat_dropout=0.0 --dropout=0.3 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity" --towers=5 --divide_input_first=True --divide_input_last=True --dataset CIFAR10 --gpu_id 0 --config "configs/superpixels_graph_classification_pna_CIFAR10.json" --lr_schedule_patience=5 # MNIST python main_superpixels.py --weight_decay=3e-6 --L=4 --hidden_dim=95 --out_dim=90 --residual=True --edge_feat=False --readout=sum --in_feat_dropout=0.0 --dropout=0.1 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity" --towers=5 --divide_input_first=True --divide_input_last=True --dataset MNIST --gpu_id 0 --config "configs/superpixels_graph_classification_pna_MNIST.json" --lr_schedule_patience=5 python main_superpixels.py --weight_decay=3e-6 --L=4 --hidden_dim=95 --out_dim=90 --residual=True --edge_feat=True --edge_dim=50 --readout=sum --in_feat_dropout=0.0 --dropout=0.3 --graph_norm=True --batch_norm=True --aggregators="mean max min std" --scalers="identity" --towers=5 --divide_input_first=True --divide_input_last=True --dataset MNIST --gpu_id 0 --config "configs/superpixels_graph_classification_pna_MNIST.json" --lr_schedule_patience=5 --- MPNN (sum/max) --- # ZINC python main_molecules.py --weight_decay=1e-5 --L=4 --hidden_dim=110 --out_dim=80 --residual=True --edge_feat=False --readout=sum --in_feat_dropout=0.0 --dropout=0.0 --graph_norm=True --batch_norm=True --aggregators="sum"/"max" --scalers="identity" --towers=5 --divide_input_first=True --divide_input_last=True --dataset ZINC --gpu_id 0 --config "configs/molecules_graph_regression_pna_ZINC.json" --lr_schedule_patience=5 python main_molecules.py --weight_decay=3e-6 --L=4 --hidden_dim=100 --out_dim=70 --residual=True --edge_dim=50 --edge_feat=True --readout=sum --in_feat_dropout=0.0 --dropout=0.0 --graph_norm=True --batch_norm=True --aggregators="sum"/"max" 
--scalers="identity" --towers=5 --divide_input_first=True --divide_input_last=True --dataset ZINC --gpu_id 0 --config "configs/molecules_graph_regression_pna_ZINC.json" --lr_schedule_patience=20 # CIFAR10 python main_superpixels.py --weight_decay=3e-6 --L=4 --hidden_dim=110 --out_dim=90 --residual=True --edge_feat=False --readout=sum --in_feat_dropout=0.0 --dropout=0.2 --graph_norm=True --batch_norm=True --aggregators="sum"/"max" --scalers="identity" --towers=5 --divide_input_first=True --divide_input_last=True --dataset CIFAR10 --gpu_id 0 --config "configs/superpixels_graph_classification_pna_CIFAR10.json" --lr_schedule_patience=5 python main_superpixels.py --weight_decay=3e-6 --L=4 --hidden_dim=110 --out_dim=90 --residual=True --edge_feat=True --edge_dim=20 --readout=sum --in_feat_dropout=0.0 --dropout=0.2 --graph_norm=True --batch_norm=True --aggregators="sum"/"max" --scalers="identity" --towers=5 --divide_input_first=True --divide_input_last=True --dataset CIFAR10 --gpu_id 0 --config "configs/superpixels_graph_classification_pna_CIFAR10.json" --lr_schedule_patience=5 # MNIST python main_superpixels.py --weight_decay=3e-6 --L=4 --hidden_dim=110 --out_dim=90 --residual=True --edge_feat=False --readout=sum --in_feat_dropout=0.0 --dropout=0.2 --graph_norm=True --batch_norm=True --aggregators="sum"/"max" --scalers="identity" --towers=5 --divide_input_first=True --divide_input_last=True --dataset MNIST --gpu_id 0 --config "configs/superpixels_graph_classification_pna_MNIST.json" --lr_schedule_patience=5 python main_superpixels.py --weight_decay=3e-6 --L=4 --hidden_dim=110 --out_dim=90 --residual=True --edge_feat=True --edge_dim=20 --readout=sum --in_feat_dropout=0.0 --dropout=0.2 --graph_norm=True --batch_norm=True --aggregators="sum"/"max" --scalers="identity" --towers=5 --divide_input_first=True --divide_input_last=True --dataset MNIST --gpu_id 0 --config "configs/superpixels_graph_classification_pna_MNIST.json" --lr_schedule_patience=5 ``` alternatively, for OGB 
leaderboard, run the following scripts in the [DGN](https://github.com/Saro00/DGN) repository: ``` # MolHIV python -m main_HIV --weight_decay=3e-6 --L=4 --hidden_dim=80 --out_dim=80 --residual=True --readout=mean --in_feat_dropout=0.0 --dropout=0.3 --batch_norm=True --aggregators="mean max min std" --scalers="identity amplification attenuation" --dataset HIV --config "configs/molecules_graph_classification_DGN_HIV.json" --epochs=200 --init_lr=0.01 --lr_reduce_factor=0.5 --lr_schedule_patience=20 --min_lr=0.0001 # MolPCBA python main_PCBA.py --type_net="complex" --batch_size=512 --lap_norm="none" --weight_decay=3e-6 --L=4 --hidden_dim=510 --out_dim=510 --residual=True --edge_feat=True --readout=sum --graph_norm=True --batch_norm=True --aggregators="mean sum max" --scalers="identity" --config "configs/molecules_graph_classification_DGN_PCBA.json" --lr_schedule_patience=4 --towers=5 --dropout=0.2 --init_lr=0.0005 --min_lr=0.00002 --edge_dim=16 --lr_reduce_factor=0.8 ``` ================================================ FILE: realworld_benchmark/configs/molecules_graph_classification_PNA_HIV.json ================================================ { "gpu": { "use": true, "id": 0 }, "model": "PNA", "dataset": "HIV", "params": { "seed": 41, "epochs": 200, "batch_size": 128, "init_lr": 0.01, "lr_reduce_factor": 0.5, "lr_schedule_patience": 20, "min_lr": 1e-4, "weight_decay": 3e-6, "print_epoch_interval": 5, "max_time": 48 }, "net_params": { "L": 4, "hidden_dim": 70, "out_dim": 70, "residual": true, "readout": "mean", "in_feat_dropout": 0.0, "dropout": 0.3, "batch_norm": true, "aggregators": "mean max min std", "scalers": "identity amplification attenuation", "posttrans_layers" : 1 } } ================================================ FILE: realworld_benchmark/configs/molecules_graph_regression_pna_ZINC.json ================================================ { "gpu": { "use": true, "id": 0 }, "model": "PNA", "dataset": "ZINC", "out_dir": "out/molecules_graph_regression/", 
"params": { "seed": 41, "epochs": 1000, "batch_size": 128, "init_lr": 0.001, "lr_reduce_factor": 0.5, "lr_schedule_patience": 5, "min_lr": 1e-5, "weight_decay": 3e-6, "print_epoch_interval": 5, "max_time": 48 }, "net_params": { "L": 4, "hidden_dim": 75, "out_dim": 70, "residual": true, "edge_feat": false, "readout": "sum", "in_feat_dropout": 0.0, "dropout": 0.0, "graph_norm": true, "batch_norm": true, "aggregators": "mean max min std", "scalers": "identity amplification attenuation", "towers": 5, "divide_input_first": false, "divide_input_last": true, "gru": false, "edge_dim": 0, "pretrans_layers" : 1, "posttrans_layers" : 1 } } ================================================ FILE: realworld_benchmark/configs/superpixels_graph_classification_pna_CIFAR10.json ================================================ { "gpu": { "use": true, "id": 0 }, "model": "PNA", "dataset": "CIFAR10", "out_dir": "out/superpixels_graph_classification/", "params": { "seed": 41, "epochs": 1000, "batch_size": 128, "init_lr": 0.001, "lr_reduce_factor": 0.5, "lr_schedule_patience": 5, "min_lr": 1e-5, "weight_decay": 3e-6, "print_epoch_interval": 5, "max_time": 48 }, "net_params": { "L": 4, "hidden_dim": 75, "out_dim": 70, "residual": true, "edge_feat": false, "readout": "sum", "in_feat_dropout": 0.0, "dropout": 0.0, "graph_norm": true, "batch_norm": true, "aggregators": "mean max min std", "scalers": "identity amplification attenuation", "towers": 5, "divide_input_first": true, "divide_input_last": false, "gru": false, "edge_dim": 0, "pretrans_layers" : 1, "posttrans_layers" : 1 } } ================================================ FILE: realworld_benchmark/configs/superpixels_graph_classification_pna_MNIST.json ================================================ { "gpu": { "use": true, "id": 0 }, "model": "PNA", "dataset": "MNIST", "out_dir": "out/superpixels_graph_classification/", "params": { "seed": 41, "epochs": 1000, "batch_size": 128, "init_lr": 0.001, "lr_reduce_factor": 0.5, 
"lr_schedule_patience": 5, "min_lr": 1e-5, "weight_decay": 3e-6, "print_epoch_interval": 5, "max_time": 48 }, "net_params": { "L": 4, "hidden_dim": 100, "out_dim": 70, "residual": true, "edge_feat": false, "readout": "sum", "in_feat_dropout": 0.0, "dropout": 0.0, "graph_norm": true, "batch_norm": true, "aggregators": "mean max min std", "scalers": "identity amplification attenuation", "towers": 5, "divide_input_first": true, "divide_input_last": false, "gru": false, "edge_dim": 0, "pretrans_layers" : 1, "posttrans_layers" : 1 } } ================================================ FILE: realworld_benchmark/data/HIV.py ================================================ import time import dgl import torch from torch.utils.data import Dataset from ogb.graphproppred import DglGraphPropPredDataset from ogb.graphproppred import Evaluator import torch.utils.data class HIVDGL(torch.utils.data.Dataset): def __init__(self, data, split): self.split = split self.data = [g for g in data[self.split]] self.graph_lists = [] self.graph_labels = [] for g in self.data: if g[0].number_of_nodes() > 5: self.graph_lists.append(g[0]) self.graph_labels.append(g[1]) self.n_samples = len(self.graph_lists) def __len__(self): """Return the number of graphs in the dataset.""" return self.n_samples def __getitem__(self, idx): """ Get the idx^th sample. Parameters --------- idx : int The sample index. Returns ------- (dgl.DGLGraph, int) DGLGraph with node feature stored in `feat` field And its label. """ return self.graph_lists[idx], self.graph_labels[idx] class HIVDataset(Dataset): def __init__(self, name, verbose=True): start = time.time() if verbose: print("[I] Loading dataset %s..." 
% (name)) self.name = name self.dataset = DglGraphPropPredDataset(name = 'ogbg-molhiv') self.split_idx = self.dataset.get_idx_split() self.train = HIVDGL(self.dataset, self.split_idx['train']) self.val = HIVDGL(self.dataset, self.split_idx['valid']) self.test = HIVDGL(self.dataset, self.split_idx['test']) self.evaluator = Evaluator(name='ogbg-molhiv') if verbose: print('train, test, val sizes :', len(self.train), len(self.test), len(self.val)) print("[I] Finished loading.") print("[I] Data load time: {:.4f}s".format(time.time() - start)) # form a mini batch from a given list of samples = [(graph, label) pairs] def collate(self, samples): # The input samples is a list of pairs (graph, label). graphs, labels = map(list, zip(*samples)) labels = torch.cat(labels).long() batched_graph = dgl.batch(graphs) return batched_graph, labels def _add_self_loops(self): # function for adding self loops # this function will be called only if self_loop flag is True self.train.graph_lists = [self_loop(g) for g in self.train.graph_lists] self.val.graph_lists = [self_loop(g) for g in self.val.graph_lists] self.test.graph_lists = [self_loop(g) for g in self.test.graph_lists] ================================================ FILE: realworld_benchmark/data/download_datasets.sh ================================================ # MIT License # Copyright (c) 2020 Vijay Prakash Dwivedi, Chaitanya K. Joshi, Thomas Laurent, Yoshua Bengio, Xavier Bresson # Command to download dataset: # bash script_download_all_datasets.sh # ZINC FILE=ZINC.pkl if test -f "$FILE"; then echo -e "$FILE already downloaded." else echo -e "\ndownloading $FILE..." curl https://www.dropbox.com/s/bhimk9p1xst6dvo/ZINC.pkl?dl=1 -o ZINC.pkl -J -L -k fi # MNIST and CIFAR10 FILE=MNIST.pkl if test -f "$FILE"; then echo -e "$FILE already downloaded." else echo -e "\ndownloading $FILE..." 
curl https://www.dropbox.com/s/wcfmo4yvnylceaz/MNIST.pkl?dl=1 -o MNIST.pkl -J -L -k fi FILE=CIFAR10.pkl if test -f "$FILE"; then echo -e "$FILE already downloaded." else echo -e "\ndownloading $FILE..." curl https://www.dropbox.com/s/agocm8pxg5u8yb5/CIFAR10.pkl?dl=1 -o CIFAR10.pkl -J -L -k fi ================================================ FILE: realworld_benchmark/data/molecules.py ================================================ # MIT License # Copyright (c) 2020 Vijay Prakash Dwivedi, Chaitanya K. Joshi, Thomas Laurent, Yoshua Bengio, Xavier Bresson import torch import pickle import torch.utils.data import time import numpy as np import csv import dgl class MoleculeDGL(torch.utils.data.Dataset): def __init__(self, data_dir, split, num_graphs): self.data_dir = data_dir self.split = split self.num_graphs = num_graphs with open(data_dir + "/%s.pickle" % self.split, "rb") as f: self.data = pickle.load(f) # loading the sampled indices from file ./zinc_molecules/.index with open(data_dir + "/%s.index" % self.split, "r") as f: data_idx = [list(map(int, idx)) for idx in csv.reader(f)] self.data = [self.data[i] for i in data_idx[0]] assert len(self.data) == num_graphs, "Sample num_graphs again; available idx: train/val/test => 10k/1k/1k" """ data is a list of Molecule dict objects with following attributes molecule = data[idx] ; molecule['num_atom'] : nb of atoms, an integer (N) ; molecule['atom_type'] : tensor of size N, each element is an atom type, an integer between 0 and num_atom_type ; molecule['bond_type'] : tensor of size N x N, each element is a bond type, an integer between 0 and num_bond_type ; molecule['logP_SA_cycle_normalized'] : the chemical property to regress, a float variable """ self.graph_lists = [] self.graph_labels = [] self.n_samples = len(self.data) self._prepare() def _prepare(self): print("preparing %d graphs for the %s set..." 
              % (self.num_graphs, self.split.upper()))
        for molecule in self.data:
            node_features = molecule['atom_type'].long()

            adj = molecule['bond_type']
            edge_list = (adj != 0).nonzero()  # converting adj matrix to edge_list

            edge_idxs_in_adj = edge_list.split(1, dim=1)
            edge_features = adj[edge_idxs_in_adj].reshape(-1).long()

            # Create the DGL Graph
            g = dgl.DGLGraph()
            g.add_nodes(molecule['num_atom'])
            g.ndata['feat'] = node_features

            for src, dst in edge_list:
                g.add_edges(src.item(), dst.item())
            g.edata['feat'] = edge_features

            self.graph_lists.append(g)
            self.graph_labels.append(molecule['logP_SA_cycle_normalized'])

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """
        Get the idx^th sample.
        Parameters
        ---------
        idx : int
            The sample index.
        Returns
        -------
        (dgl.DGLGraph, int)
            DGLGraph with node feature stored in `feat` field, and its label.
        """
        return self.graph_lists[idx], self.graph_labels[idx]


class MoleculeDatasetDGL(torch.utils.data.Dataset):
    def __init__(self, name='Zinc'):
        t0 = time.time()
        self.name = name
        self.num_atom_type = 28  # known meta-info about the zinc dataset; can be calculated as well
        self.num_bond_type = 4   # known meta-info about the zinc dataset; can be calculated as well
        data_dir = './data/molecules'
        self.train = MoleculeDGL(data_dir, 'train', num_graphs=10000)
        self.val = MoleculeDGL(data_dir, 'val', num_graphs=1000)
        self.test = MoleculeDGL(data_dir, 'test', num_graphs=1000)
        print("Time taken: {:.4f}s".format(time.time() - t0))


def self_loop(g):
    """
    Utility function, to be used only when necessary as per the user's self_loop flag:
    overwrites dgl.transform.add_self_loop() so as not to miss ndata['feat'] and edata['feat'].
    This function is called inside a function in the MoleculeDataset class.
    """
    new_g = dgl.DGLGraph()
    new_g.add_nodes(g.number_of_nodes())
    new_g.ndata['feat'] = g.ndata['feat']

    src, dst = g.all_edges(order="eid")
    src = dgl.backend.zerocopy_to_numpy(src)
    dst = dgl.backend.zerocopy_to_numpy(dst)
    non_self_edges_idx = src != dst
    nodes = np.arange(g.number_of_nodes())
    new_g.add_edges(src[non_self_edges_idx], dst[non_self_edges_idx])
    new_g.add_edges(nodes, nodes)

    # This new edata is not used since this function gets called only for GCN, GAT
    # However, we need this for the generic requirement of ndata and edata
    new_g.edata['feat'] = torch.zeros(new_g.number_of_edges())
    return new_g


class MoleculeDataset(torch.utils.data.Dataset):
    def __init__(self, name):
        """
        Loading molecule datasets
        """
        start = time.time()
        print("[I] Loading dataset %s..." % (name))
        self.name = name
        data_dir = 'data/'
        with open(data_dir + name + '.pkl', "rb") as f:
            f = pickle.load(f)
            self.train = f[0]
            self.val = f[1]
            self.test = f[2]
            self.num_atom_type = f[3]
            self.num_bond_type = f[4]
        print('train, test, val sizes :', len(self.train), len(self.test), len(self.val))
        print("[I] Finished loading.")
        print("[I] Data load time: {:.4f}s".format(time.time() - start))

    # form a mini batch from a given list of samples = [(graph, label) pairs]
    def collate(self, samples):
        # The input samples is a list of pairs (graph, label).
        graphs, labels = map(list, zip(*samples))
        labels = torch.tensor(np.array(labels)).unsqueeze(1)
        tab_sizes_n = [graphs[i].number_of_nodes() for i in range(len(graphs))]
        tab_snorm_n = [torch.FloatTensor(size, 1).fill_(1. / float(size)) for size in tab_sizes_n]
        snorm_n = torch.cat(tab_snorm_n).sqrt()
        tab_sizes_e = [graphs[i].number_of_edges() for i in range(len(graphs))]
        tab_snorm_e = [torch.FloatTensor(size, 1).fill_(1. / float(size)) for size in tab_sizes_e]
        snorm_e = torch.cat(tab_snorm_e).sqrt()
        batched_graph = dgl.batch(graphs)
        return batched_graph, labels, snorm_n, snorm_e

    def _add_self_loops(self):
        # function for adding self loops
        # this function will be called only if self_loop flag is True
        self.train.graph_lists = [self_loop(g) for g in self.train.graph_lists]
        self.val.graph_lists = [self_loop(g) for g in self.val.graph_lists]
        self.test.graph_lists = [self_loop(g) for g in self.test.graph_lists]


================================================ FILE: realworld_benchmark/data/superpixels.py ================================================
# MIT License
# Copyright (c) 2020 Vijay Prakash Dwivedi, Chaitanya K. Joshi, Thomas Laurent, Yoshua Bengio, Xavier Bresson

import os
import pickle
from scipy.spatial.distance import cdist
import numpy as np
import itertools

import dgl
import torch
import torch.utils.data
import time
import csv
from sklearn.model_selection import StratifiedShuffleSplit


def sigma(dists, kth=8):
    # Compute sigma and reshape
    try:
        # Get k-nearest neighbors for each node
        knns = np.partition(dists, kth, axis=-1)[:, kth::-1]
        sigma = knns.sum(axis=1).reshape((knns.shape[0], 1)) / kth
    except ValueError:  # handling for graphs with num_nodes less than kth
        num_nodes = dists.shape[0]
        # this sigma value is irrelevant since it is not used for the final compute_edges_list
        sigma = np.array([1] * num_nodes).reshape(num_nodes, 1)
    return sigma + 1e-8  # adding epsilon to avoid zero value of sigma


def compute_adjacency_matrix_images(coord, feat, use_feat=True, kth=8):
    coord = coord.reshape(-1, 2)
    # Compute coordinate distance
    c_dist = cdist(coord, coord)

    if use_feat:
        # Compute feature distance
        f_dist = cdist(feat, feat)
        # Compute adjacency
        A = np.exp(-(c_dist / sigma(c_dist)) ** 2 - (f_dist / sigma(f_dist)) ** 2)
    else:
        A = np.exp(-(c_dist / sigma(c_dist)) ** 2)

    # Convert to symmetric matrix
    A = 0.5 * (A + A.T)
    A[np.diag_indices_from(A)] = 0
    return A


def compute_edges_list(A, kth=8 + 1):
    # Get k-similar neighbor indices for each node
    num_nodes = A.shape[0]
    new_kth = num_nodes - kth

    if num_nodes > 9:
        knns = np.argpartition(A, new_kth - 1, axis=-1)[:, new_kth:-1]
        knn_values = np.partition(A, new_kth - 1, axis=-1)[:, new_kth:-1]  # NEW
    else:
        # handling for graphs with less than kth nodes
        # in such cases, the resulting graph will be fully connected
        knns = np.tile(np.arange(num_nodes), num_nodes).reshape(num_nodes, num_nodes)
        knn_values = A  # NEW

        # removing self loop
        if num_nodes != 1:
            knn_values = A[knns != np.arange(num_nodes)[:, None]].reshape(num_nodes, -1)  # NEW
            knns = knns[knns != np.arange(num_nodes)[:, None]].reshape(num_nodes, -1)
    return knns, knn_values  # NEW


class SuperPixDGL(torch.utils.data.Dataset):
    def __init__(self, data_dir, dataset, split, use_mean_px=True, use_coord=True):
        self.split = split
        self.graph_lists = []

        if dataset == 'MNIST':
            self.img_size = 28
            with open(os.path.join(data_dir, 'mnist_75sp_%s.pkl' % split), 'rb') as f:
                self.labels, self.sp_data = pickle.load(f)
                self.graph_labels = torch.LongTensor(self.labels)
        elif dataset == 'CIFAR10':
            self.img_size = 32
            with open(os.path.join(data_dir, 'cifar10_150sp_%s.pkl' % split), 'rb') as f:
                self.labels, self.sp_data = pickle.load(f)
                self.graph_labels = torch.LongTensor(self.labels)

        self.use_mean_px = use_mean_px
        self.use_coord = use_coord
        self.n_samples = len(self.labels)
        self._prepare()

    def _prepare(self):
        print("preparing %d graphs for the %s set..." % (self.n_samples, self.split.upper()))
        self.Adj_matrices, self.node_features, self.edges_lists, self.edge_features = [], [], [], []
        for index, sample in enumerate(self.sp_data):
            mean_px, coord = sample[:2]

            try:
                coord = coord / self.img_size
            except AttributeError:
                VOC_has_variable_image_sizes = True

            if self.use_mean_px:
                A = compute_adjacency_matrix_images(coord, mean_px)  # using super-pixel locations + features
            else:
                A = compute_adjacency_matrix_images(coord, mean_px, False)  # using only super-pixel locations
            edges_list, edge_values_list = compute_edges_list(A)  # NEW

            N_nodes = A.shape[0]

            mean_px = mean_px.reshape(N_nodes, -1)
            coord = coord.reshape(N_nodes, 2)
            x = np.concatenate((mean_px, coord), axis=1)

            edge_values_list = edge_values_list.reshape(-1)  # NEW  # TO DOUBLE-CHECK !

            self.node_features.append(x)
            self.edge_features.append(edge_values_list)  # NEW
            self.Adj_matrices.append(A)
            self.edges_lists.append(edges_list)

        for index in range(len(self.sp_data)):
            g = dgl.DGLGraph()
            g.add_nodes(self.node_features[index].shape[0])
            g.ndata['feat'] = torch.Tensor(self.node_features[index]).half()

            for src, dsts in enumerate(self.edges_lists[index]):
                # handling for 1 node where the self loop would be the only edge
                # since VOC Superpixels has few samples (5 samples) with only 1 node
                if self.node_features[index].shape[0] == 1:
                    g.add_edges(src, dsts)
                else:
                    g.add_edges(src, dsts[dsts != src])

            # adding edge features for Residual Gated ConvNet
            edge_feat_dim = g.ndata['feat'].shape[1]  # dim same as node feature dim
            # g.edata['feat'] = torch.ones(g.number_of_edges(), edge_feat_dim).half()
            g.edata['feat'] = torch.Tensor(self.edge_features[index]).unsqueeze(1).half()  # NEW

            self.graph_lists.append(g)

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """
        Get the idx^th sample.
        Parameters
        ---------
        idx : int
            The sample index.
        Returns
        -------
        (dgl.DGLGraph, int)
            DGLGraph with node feature stored in `feat` field, and its label.
        """
        return self.graph_lists[idx], self.graph_labels[idx]


class DGLFormDataset(torch.utils.data.Dataset):
    """
    DGLFormDataset wrapping graph list and label list as per pytorch Dataset.
    *lists (list): lists of 'graphs' and 'labels' with same len().
    """
    def __init__(self, *lists):
        assert all(len(lists[0]) == len(li) for li in lists)
        self.lists = lists
        self.graph_lists = lists[0]
        self.graph_labels = lists[1]

    def __getitem__(self, index):
        return tuple(li[index] for li in self.lists)

    def __len__(self):
        return len(self.lists[0])


class SuperPixDatasetDGL(torch.utils.data.Dataset):
    def __init__(self, name, num_val=5000):
        """
        Takes as input a standard image dataset name (MNIST/CIFAR10)
        and returns the superpixels graph.

        This class uses results from the SuperPixDGL class above, which contains
        the steps for the generation of the superpixels graph from a superpixel
        .pkl file provided by https://github.com/bknyaz/graph_attention_pool

        Please refer to the SuperPixDGL class for details.
        """
        t_data = time.time()
        self.name = name

        use_mean_px = True   # using super-pixel locations + features
        use_mean_px = False  # using only super-pixel locations
        if use_mean_px:
            print('Adj matrix defined from super-pixel locations + features')
        else:
            print('Adj matrix defined from super-pixel locations (only)')
        use_coord = True

        self.test = SuperPixDGL("./data/superpixels", dataset=self.name, split='test',
                                use_mean_px=use_mean_px, use_coord=use_coord)
        self.train_ = SuperPixDGL("./data/superpixels", dataset=self.name, split='train',
                                  use_mean_px=use_mean_px, use_coord=use_coord)

        _val_graphs, _val_labels = self.train_[:num_val]
        _train_graphs, _train_labels = self.train_[num_val:]

        self.val = DGLFormDataset(_val_graphs, _val_labels)
        self.train = DGLFormDataset(_train_graphs, _train_labels)

        print("[I] Data load time: {:.4f}s".format(time.time() - t_data))


def self_loop(g):
    """
    Utility function, to be used only when necessary as per the user's self_loop flag:
    overwrites dgl.transform.add_self_loop() so as not to miss ndata['feat'] and edata['feat'].
    This function is called inside a function in the SuperPixDataset class.
    """
    new_g = dgl.DGLGraph()
    new_g.add_nodes(g.number_of_nodes())
    new_g.ndata['feat'] = g.ndata['feat']

    src, dst = g.all_edges(order="eid")
    src = dgl.backend.zerocopy_to_numpy(src)
    dst = dgl.backend.zerocopy_to_numpy(dst)
    non_self_edges_idx = src != dst
    nodes = np.arange(g.number_of_nodes())
    new_g.add_edges(src[non_self_edges_idx], dst[non_self_edges_idx])
    new_g.add_edges(nodes, nodes)

    # This new edata is not used since this function gets called only for GCN, GAT
    # However, we need this for the generic requirement of ndata and edata
    new_g.edata['feat'] = torch.zeros(new_g.number_of_edges())
    return new_g


class SuperPixDataset(torch.utils.data.Dataset):
    def __init__(self, name):
        """
        Loading Superpixels datasets
        """
        start = time.time()
        print("[I] Loading dataset %s..." % (name))
        self.name = name
        data_dir = 'data/'
        with open(data_dir + name + '.pkl', "rb") as f:
            f = pickle.load(f)
            self.train = f[0]
            self.val = f[1]
            self.test = f[2]
        print('train, test, val sizes :', len(self.train), len(self.test), len(self.val))
        print("[I] Finished loading.")
        print("[I] Data load time: {:.4f}s".format(time.time() - start))

    # form a mini batch from a given list of samples = [(graph, label) pairs]
    def collate(self, samples):
        # The input samples is a list of pairs (graph, label).
        graphs, labels = map(list, zip(*samples))
        labels = torch.tensor(np.array(labels))
        tab_sizes_n = [graphs[i].number_of_nodes() for i in range(len(graphs))]
        tab_snorm_n = [torch.FloatTensor(size, 1).fill_(1. / float(size)) for size in tab_sizes_n]
        snorm_n = torch.cat(tab_snorm_n).sqrt()
        tab_sizes_e = [graphs[i].number_of_edges() for i in range(len(graphs))]
        tab_snorm_e = [torch.FloatTensor(size, 1).fill_(1. / float(size)) for size in tab_sizes_e]
        snorm_e = torch.cat(tab_snorm_e).sqrt()
        for idx, graph in enumerate(graphs):
            graphs[idx].ndata['feat'] = graph.ndata['feat'].float()
            graphs[idx].edata['feat'] = graph.edata['feat'].float()
        batched_graph = dgl.batch(graphs)
        return batched_graph, labels, snorm_n, snorm_e

    def _add_self_loops(self):
        # function for adding self loops
        # this function will be called only if self_loop flag is True
        self.train.graph_lists = [self_loop(g) for g in self.train.graph_lists]
        self.val.graph_lists = [self_loop(g) for g in self.val.graph_lists]
        self.test.graph_lists = [self_loop(g) for g in self.test.graph_lists]

        self.train = DGLFormDataset(self.train.graph_lists, self.train.graph_labels)
        self.val = DGLFormDataset(self.val.graph_lists, self.val.graph_labels)
        self.test = DGLFormDataset(self.test.graph_lists, self.test.graph_labels)


================================================ FILE: realworld_benchmark/docs/setup.md ================================================
# Benchmark setup
## 1. Setup Conda

```
# Conda installation

# For Linux
curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh

# For OSX
curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh

chmod +x ~/miniconda.sh
~/miniconda.sh

source ~/.bashrc        # For Linux
source ~/.bash_profile  # For OSX
```
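Before moving on, it can help to confirm that the tools the next steps rely on are actually on your `PATH`. A minimal sketch (the `missing_tools` helper is ours, not part of the repo):

```python
# Report which of the required command-line tools (conda for the environment,
# git for cloning the repo) cannot be found on PATH.
import shutil

def missing_tools(tools=("conda", "git")):
    """Return the subset of `tools` that shutil.which cannot locate."""
    return [t for t in tools if shutil.which(t) is None]

print(missing_tools())  # an empty list means step 1 succeeded
```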
## 2. Setup Python environment for CPU

```
# Clone GitHub repo
conda install git
git clone https://github.com/lukecavabarrett/pna.git
cd pna

# Install python environment
conda env create -f environment_cpu.yml

# Activate environment
conda activate benchmark_gnn
```
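A quick way to sanity-check the freshly created environment is to try importing the key packages pinned in `environment_cpu.yml` and print their versions. This is a sketch of ours, not a script shipped with the repo; it reports missing packages instead of crashing:

```python
# Import each core dependency from environment_cpu.yml and record its version;
# None marks a package that failed to import (i.e. the environment is incomplete).
import importlib

def check_environment(packages=("torch", "dgl", "numpy", "networkx", "sklearn")):
    report = {}
    for name in packages:
        try:
            module = importlib.import_module(name)
            report[name] = getattr(module, "__version__", "unknown")
        except ImportError:
            report[name] = None
    return report

for pkg, version in check_environment().items():
    print("%-10s %s" % (pkg, version if version is not None else "MISSING"))
```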
## 3. Setup Python environment for GPU

DGL requires CUDA **10.0**.

For Ubuntu **18.04**:

```
# Setup CUDA 10.0 on Ubuntu 18.04
sudo apt-get --purge remove "*cublas*" "cuda*"
sudo apt --purge remove "nvidia*"
sudo apt autoremove
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-repo-ubuntu1804_10.0.130-1_amd64.deb
sudo dpkg -i cuda-repo-ubuntu1804_10.0.130-1_amd64.deb
sudo apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
sudo apt update
sudo apt install -y cuda-10-0
sudo reboot
cat /usr/local/cuda/version.txt  # Check CUDA version is 10.0

# Clone GitHub repo
conda install git
git clone https://github.com/lukecavabarrett/pna.git
cd pna

# Install python environment
conda env create -f environment_gpu.yml

# Activate environment
conda activate benchmark_gnn_gpu
```

For Ubuntu **16.04**:

```
# Setup CUDA 10.0 on Ubuntu 16.04
sudo apt-get --purge remove "*cublas*" "cuda*"
sudo apt --purge remove "nvidia*"
sudo apt autoremove
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_10.0.130-1_amd64.deb
sudo dpkg -i cuda-repo-ubuntu1604_10.0.130-1_amd64.deb
sudo apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/7fa2af80.pub
sudo apt update
sudo apt install -y cuda-10-0
sudo reboot
cat /usr/local/cuda/version.txt  # Check CUDA version is 10.0

# Clone GitHub repo
conda install git
git clone https://github.com/lukecavabarrett/pna.git
cd pna

# Install python environment
conda env create -f environment_gpu.yml

# Activate environment
conda activate benchmark_gnn_gpu
```

## 4. Download Datasets

```
# At the root of the repo
cd realworld_benchmark/data/
bash download_datasets.sh
```
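After `download_datasets.sh` finishes, the training scripts unpickle files such as `data/ZINC.pkl`, which `MoleculeDataset` unpacks as a 5-tuple `(train, val, test, num_atom_type, num_bond_type)`. A minimal sketch of that layout, using a tiny synthetic file since the real one is large (the `load_molecule_pickle` helper name is ours):

```python
# Demonstrate the pickle layout MoleculeDataset expects by writing and
# re-reading a tiny synthetic stand-in for data/ZINC.pkl.
import os
import pickle
import tempfile

def load_molecule_pickle(path):
    """Unpack a ZINC-style pickle the way MoleculeDataset.__init__ does."""
    with open(path, "rb") as f:
        train, val, test, num_atom_type, num_bond_type = pickle.load(f)
    return train, val, test, num_atom_type, num_bond_type

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ZINC.pkl")
    with open(path, "wb") as f:
        # two fake train graphs, one val, one test; 28 atom types, 4 bond types
        pickle.dump((["g0", "g1"], ["g2"], ["g3"], 28, 4), f)
    train, val, test, n_atom, n_bond = load_molecule_pickle(path)
    print(len(train), len(val), len(test), n_atom, n_bond)  # 2 1 1 28 4
```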


================================================ FILE: realworld_benchmark/environment_cpu.yml ================================================ # MIT License # Copyright (c) 2020 Vijay Prakash Dwivedi, Chaitanya K. Joshi, Thomas Laurent, Yoshua Bengio, Xavier Bresson name: benchmark_gnn channels: - pytorch - dglteam - conda-forge dependencies: - python=3.7.4 - python-dateutil=2.8.0 - pytorch=1.3 - torchvision==0.4.2 - pillow==6.1 - dgl=0.4.2 - numpy=1.16.4 - matplotlib=3.1.0 - tensorboard=1.14.0 - tensorboardx=1.8 - absl-py - networkx=2.3 - scikit-learn=0.21.2 - scipy=1.3.0 - notebook=6.0.0 - h5py=2.9.0 - mkl=2019.4 - ipykernel=5.1.2 - ipython=7.7.0 - ipython_genutils=0.2.0 - ipywidgets=7.5.1 - jupyter=1.0.0 - jupyter_client=5.3.1 - jupyter_console=6.0.0 - jupyter_core=4.5.0 - plotly=4.1.1 - scikit-image=0.15.0 - requests==2.22.0 - tqdm==4.43.0 - pip: - ogb==1.2.2 ================================================ FILE: realworld_benchmark/environment_gpu.yml ================================================ # MIT License # Copyright (c) 2020 Vijay Prakash Dwivedi, Chaitanya K. 
Joshi, Thomas Laurent, Yoshua Bengio, Xavier Bresson name: benchmark_gnn_gpu channels: - pytorch - dglteam - conda-forge - fragcolor dependencies: - cuda10.0 - cudatoolkit=10.0 - cudnn=7.6.5 - python=3.7.4 - python-dateutil=2.8.0 - pytorch=1.3 - torchvision==0.4.2 - pillow==6.1 - dgl-cuda10.0=0.4.2 - numpy=1.16.4 - matplotlib=3.1.0 - tensorboard=1.14.0 - tensorboardx=1.8 - absl-py - networkx=2.3 - scikit-learn=0.21.2 - scipy=1.3.0 - notebook=6.0.0 - h5py=2.9.0 - mkl=2019.4 - ipykernel=5.1.2 - ipython=7.7.0 - ipython_genutils=0.2.0 - ipywidgets=7.5.1 - jupyter=1.0.0 - jupyter_client=5.3.1 - jupyter_console=6.0.0 - jupyter_core=4.5.0 - plotly=4.1.1 - scikit-image=0.15.0 - requests==2.22.0 - tqdm==4.43.0 - pip: - ogb==1.2.2 ================================================ FILE: realworld_benchmark/main_HIV.py ================================================ import numpy as np import os import time import random import argparse, json import torch import torch.optim as optim from torch.utils.data import DataLoader from tqdm import tqdm from nets.HIV_graph_classification.pna_net import PNANet from data.HIV import HIVDataset # import dataset from train.train_HIV_graph_classification import train_epoch_sparse as train_epoch, \ evaluate_network_sparse as evaluate_network def gpu_setup(use_gpu, gpu_id): os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) if torch.cuda.is_available() and use_gpu: print('cuda available with GPU:', torch.cuda.get_device_name(0)) device = torch.device("cuda") else: print('cuda not available') device = torch.device("cpu") return device def view_model_param(net_params): model = PNANet(net_params) total_param = 0 print("MODEL DETAILS:\n") # print(model) for param in model.parameters(): # print(param.data.size()) total_param += np.prod(list(param.data.size())) print('PNA Total parameters:', total_param) return total_param def train_val_pipeline(dataset, params, net_params): t0 = time.time() per_epoch_time 
= [] trainset, valset, testset = dataset.train, dataset.val, dataset.test device = net_params['device'] # setting seeds random.seed(params['seed']) np.random.seed(params['seed']) torch.manual_seed(params['seed']) if device.type == 'cuda': torch.cuda.manual_seed(params['seed']) print("Training Graphs: ", len(trainset)) print("Validation Graphs: ", len(valset)) print("Test Graphs: ", len(testset)) model = PNANet(net_params) model = model.to(device) optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay']) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=params['lr_reduce_factor'], patience=params['lr_schedule_patience'], verbose=True) epoch_train_losses, epoch_val_losses = [], [] epoch_train_ROCs, epoch_val_ROCs, epoch_test_ROCs = [], [], [] train_loader = DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, collate_fn=dataset.collate, pin_memory=True) val_loader = DataLoader(valset, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate, pin_memory=True) test_loader = DataLoader(testset, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate, pin_memory=True) # At any point you can hit Ctrl + C to break out of training early. 
try: with tqdm(range(params['epochs']), unit='epoch') as t: for epoch in t: if epoch == -1: model.reset_params() t.set_description('Epoch %d' % epoch) start = time.time() epoch_train_loss, epoch_train_roc, optimizer = train_epoch(model, optimizer, device, train_loader, epoch) epoch_val_loss, epoch_val_roc = evaluate_network(model, device, val_loader, epoch) epoch_train_losses.append(epoch_train_loss) epoch_val_losses.append(epoch_val_loss) epoch_train_ROCs.append(epoch_train_roc.item()) epoch_val_ROCs.append(epoch_val_roc.item()) _, epoch_test_roc = evaluate_network(model, device, test_loader, epoch) epoch_test_ROCs.append(epoch_test_roc.item()) t.set_postfix(time=time.time() - start, lr=optimizer.param_groups[0]['lr'], train_loss=epoch_train_loss, val_loss=epoch_val_loss, train_ROC=epoch_train_roc.item(), val_ROC=epoch_val_roc.item(), test_ROC=epoch_test_roc.item(), refresh=False) per_epoch_time.append(time.time() - start) scheduler.step(-epoch_val_roc.item()) if optimizer.param_groups[0]['lr'] < params['min_lr']: print("\n!! 
LR EQUAL TO MIN LR SET.") break # Stop training after params['max_time'] hours if time.time() - t0 > params['max_time'] * 3600: print('-' * 89) print("Max_time for training elapsed {:.2f} hours, so stopping".format(params['max_time'])) break print('') except KeyboardInterrupt: print('-' * 89) print('Exiting from training early because of KeyboardInterrupt') best_val_epoch = np.argmax(np.array(epoch_val_ROCs)) best_train_epoch = np.argmax(np.array(epoch_train_ROCs)) best_val_roc = epoch_val_ROCs[best_val_epoch] best_val_test_roc = epoch_test_ROCs[best_val_epoch] best_val_train_roc = epoch_train_ROCs[best_val_epoch] best_train_roc = epoch_train_ROCs[best_train_epoch] print("Best Train ROC: {:.4f}".format(best_train_roc)) print("Best Val ROC: {:.4f}".format(best_val_roc)) print("Test ROC of Best Val: {:.4f}".format(best_val_test_roc)) print("Train ROC of Best Val: {:.4f}".format(best_val_train_roc)) print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0)) print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time))) def main(): parser = argparse.ArgumentParser() parser.add_argument('--config', help="Please give a config.json file with training/model/data/param details") parser.add_argument('--gpu_id', help="Please give a value for gpu id") parser.add_argument('--dataset', help="Please give a value for dataset name") parser.add_argument('--seed', help="Please give a value for seed") parser.add_argument('--epochs', type=int, help="Please give a value for epochs") parser.add_argument('--batch_size', help="Please give a value for batch_size") parser.add_argument('--init_lr', help="Please give a value for init_lr") parser.add_argument('--lr_reduce_factor', help="Please give a value for lr_reduce_factor") parser.add_argument('--lr_schedule_patience', help="Please give a value for lr_schedule_patience") parser.add_argument('--min_lr', help="Please give a value for min_lr") parser.add_argument('--weight_decay', help="Please give a value for weight_decay") 
parser.add_argument('--print_epoch_interval', help="Please give a value for print_epoch_interval") parser.add_argument('--L', help="Please give a value for L") parser.add_argument('--hidden_dim', help="Please give a value for hidden_dim") parser.add_argument('--out_dim', help="Please give a value for out_dim") parser.add_argument('--residual', help="Please give a value for residual") parser.add_argument('--edge_feat', help="Please give a value for edge_feat") parser.add_argument('--readout', help="Please give a value for readout") parser.add_argument('--in_feat_dropout', help="Please give a value for in_feat_dropout") parser.add_argument('--dropout', help="Please give a value for dropout") parser.add_argument('--batch_norm', help="Please give a value for batch_norm") parser.add_argument('--max_time', help="Please give a value for max_time") parser.add_argument('--expid', help='Experiment id.') parser.add_argument('--aggregators', type=str, help='Aggregators to use.') parser.add_argument('--scalers', type=str, help='Scalers to use.') parser.add_argument('--posttrans_layers', type=int, help='posttrans_layers.') args = parser.parse_args() print(args.config) with open(args.config) as f: config = json.load(f) # device if args.gpu_id is not None: config['gpu']['id'] = int(args.gpu_id) config['gpu']['use'] = True device = gpu_setup(config['gpu']['use'], config['gpu']['id']) # dataset, out_dir if args.dataset is not None: DATASET_NAME = args.dataset else: DATASET_NAME = config['dataset'] dataset = HIVDataset(DATASET_NAME) # parameters params = config['params'] if args.seed is not None: params['seed'] = int(args.seed) if args.epochs is not None: params['epochs'] = int(args.epochs) if args.batch_size is not None: params['batch_size'] = int(args.batch_size) if args.init_lr is not None: params['init_lr'] = float(args.init_lr) if args.lr_reduce_factor is not None: params['lr_reduce_factor'] = float(args.lr_reduce_factor) if args.lr_schedule_patience is not None: 
params['lr_schedule_patience'] = int(args.lr_schedule_patience) if args.min_lr is not None: params['min_lr'] = float(args.min_lr) if args.weight_decay is not None: params['weight_decay'] = float(args.weight_decay) if args.print_epoch_interval is not None: params['print_epoch_interval'] = int(args.print_epoch_interval) if args.max_time is not None: params['max_time'] = float(args.max_time) # network parameters net_params = config['net_params'] net_params['device'] = device net_params['gpu_id'] = config['gpu']['id'] net_params['batch_size'] = params['batch_size'] if args.L is not None: net_params['L'] = int(args.L) if args.hidden_dim is not None: net_params['hidden_dim'] = int(args.hidden_dim) if args.out_dim is not None: net_params['out_dim'] = int(args.out_dim) if args.residual is not None: net_params['residual'] = True if args.residual == 'True' else False if args.edge_feat is not None: net_params['edge_feat'] = True if args.edge_feat == 'True' else False if args.readout is not None: net_params['readout'] = args.readout if args.in_feat_dropout is not None: net_params['in_feat_dropout'] = float(args.in_feat_dropout) if args.dropout is not None: net_params['dropout'] = float(args.dropout) if args.batch_norm is not None: net_params['batch_norm'] = True if args.batch_norm == 'True' else False if args.aggregators is not None: net_params['aggregators'] = args.aggregators if args.scalers is not None: net_params['scalers'] = args.scalers if args.posttrans_layers is not None: net_params['posttrans_layers'] = args.posttrans_layers D = torch.cat([torch.sparse.sum(g.adjacency_matrix(transpose=True), dim=-1).to_dense() for g in dataset.train.graph_lists]) net_params['avg_d'] = dict(lin=torch.mean(D), exp=torch.mean(torch.exp(torch.div(1, D)) - 1), log=torch.mean(torch.log(D + 1))) net_params['total_param'] = view_model_param(net_params) train_val_pipeline(dataset, params, net_params) main() ================================================ FILE: 
realworld_benchmark/main_molecules.py ================================================ """ IMPORTING LIBS """ import numpy as np import os import time import random import argparse, json import torch import torch.optim as optim from torch.utils.data import DataLoader from tensorboardX import SummaryWriter from tqdm import tqdm class DotDict(dict): def __init__(self, **kwds): self.update(kwds) self.__dict__ = self """ IMPORTING CUSTOM MODULES/METHODS """ from nets.molecules_graph_regression.pna_net import PNANet from data.molecules import MoleculeDataset # import dataset from train.train_molecules_graph_regression import train_epoch, evaluate_network """ GPU Setup """ def gpu_setup(use_gpu, gpu_id): os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) if torch.cuda.is_available() and use_gpu: print('cuda available with GPU:', torch.cuda.get_device_name(0)) device = torch.device("cuda") else: print('cuda not available') device = torch.device("cpu") return device """ VIEWING MODEL CONFIG AND PARAMS """ def view_model_param(net_params): model = PNANet(net_params) total_param = 0 print("MODEL DETAILS:\n") # print(model) for param in model.parameters(): # print(param.data.size()) total_param += np.prod(list(param.data.size())) print('PNA Total parameters:', total_param) return total_param """ TRAINING CODE """ def train_val_pipeline(dataset, params, net_params, dirs): t0 = time.time() per_epoch_time = [] DATASET_NAME = dataset.name MODEL_NAME = 'PNA' trainset, valset, testset = dataset.train, dataset.val, dataset.test root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs device = net_params['device'] # Write the network and optimization hyper-parameters in folder config/ with open(write_config_file + '.txt', 'w') as f: f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n""".format( DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param'])) log_dir = 
os.path.join(root_log_dir, "RUN_" + str(0)) writer = SummaryWriter(log_dir=log_dir) # setting seeds random.seed(params['seed']) np.random.seed(params['seed']) torch.manual_seed(params['seed']) if device.type == 'cuda': torch.cuda.manual_seed(params['seed']) print("Training Graphs: ", len(trainset)) print("Validation Graphs: ", len(valset)) print("Test Graphs: ", len(testset)) model = PNANet(net_params) model = model.to(device) optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay']) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=params['lr_reduce_factor'], patience=params['lr_schedule_patience'], verbose=True) epoch_train_losses, epoch_val_losses = [], [] epoch_train_MAEs, epoch_val_MAEs = [], [] train_loader = DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, collate_fn=dataset.collate) val_loader = DataLoader(valset, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate) test_loader = DataLoader(testset, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate) # At any point you can hit Ctrl + C to break out of training early. 
try: with tqdm(range(params['epochs']), unit='epoch') as t: for epoch in t: t.set_description('Epoch %d' % epoch) start = time.time() epoch_train_loss, epoch_train_mae, optimizer = train_epoch(model, optimizer, device, train_loader, epoch) epoch_val_loss, epoch_val_mae = evaluate_network(model, device, val_loader, epoch) epoch_train_losses.append(epoch_train_loss) epoch_val_losses.append(epoch_val_loss) epoch_train_MAEs.append(epoch_train_mae.detach().cpu().item()) epoch_val_MAEs.append(epoch_val_mae.detach().cpu().item()) writer.add_scalar('train/_loss', epoch_train_loss, epoch) writer.add_scalar('val/_loss', epoch_val_loss, epoch) writer.add_scalar('train/_mae', epoch_train_mae, epoch) writer.add_scalar('val/_mae', epoch_val_mae, epoch) writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch) _, epoch_test_mae = evaluate_network(model, device, test_loader, epoch) t.set_postfix(time=time.time() - start, lr=optimizer.param_groups[0]['lr'], train_loss=epoch_train_loss, val_loss=epoch_val_loss, train_MAE=epoch_train_mae.item(), val_MAE=epoch_val_mae.item(), test_MAE=epoch_test_mae.item(), refresh=False) per_epoch_time.append(time.time() - start) scheduler.step(epoch_val_loss) if optimizer.param_groups[0]['lr'] < params['min_lr']: print("\n!! 
LR EQUAL TO MIN LR SET.") break # Stop training after params['max_time'] hours if time.time() - t0 > params['max_time'] * 3600: print('-' * 89) print("Max_time for training elapsed {:.2f} hours, so stopping".format(params['max_time'])) break except KeyboardInterrupt: print('-' * 89) print('Exiting from training early because of KeyboardInterrupt') _, test_mae = evaluate_network(model, device, test_loader, epoch) _, val_mae = evaluate_network(model, device, val_loader, epoch) _, train_mae = evaluate_network(model, device, train_loader, epoch) test_mae = test_mae.item() val_mae = val_mae.item() train_mae = train_mae.item() print("Train MAE: {:.4f}".format(train_mae)) print("Val MAE: {:.4f}".format(val_mae)) print("Test MAE: {:.4f}".format(test_mae)) print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0)) print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time))) writer.close() """ Write the results in out_dir/results folder """ with open(write_file_name + '.txt', 'w') as f: f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n FINAL RESULTS\nTEST MAE: {:.4f}\nTRAIN MAE: {:.4f}\n\n Total Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n""" \ .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'], np.mean(np.array(test_mae)), np.array(train_mae), (time.time() - t0) / 3600, np.mean(per_epoch_time))) def main(): """ USER CONTROLS """ parser = argparse.ArgumentParser() parser.add_argument('--config', help="Please give a config.json file with training/model/data/param details") parser.add_argument('--gpu_id', help="Please give a value for gpu id") parser.add_argument('--model', help="Please give a value for model name") parser.add_argument('--dataset', help="Please give a value for dataset name") parser.add_argument('--out_dir', help="Please give a value for out_dir") parser.add_argument('--seed', help="Please give a value for seed") parser.add_argument('--epochs', 
help="Please give a value for epochs") parser.add_argument('--batch_size', help="Please give a value for batch_size") parser.add_argument('--init_lr', help="Please give a value for init_lr") parser.add_argument('--lr_reduce_factor', help="Please give a value for lr_reduce_factor") parser.add_argument('--lr_schedule_patience', help="Please give a value for lr_schedule_patience") parser.add_argument('--min_lr', help="Please give a value for min_lr") parser.add_argument('--weight_decay', help="Please give a value for weight_decay") parser.add_argument('--print_epoch_interval', help="Please give a value for print_epoch_interval") parser.add_argument('--L', help="Please give a value for L") parser.add_argument('--hidden_dim', help="Please give a value for hidden_dim") parser.add_argument('--out_dim', help="Please give a value for out_dim") parser.add_argument('--residual', help="Please give a value for residual") parser.add_argument('--edge_feat', help="Please give a value for edge_feat") parser.add_argument('--readout', help="Please give a value for readout") parser.add_argument('--kernel', help="Please give a value for kernel") parser.add_argument('--n_heads', help="Please give a value for n_heads") parser.add_argument('--gated', help="Please give a value for gated") parser.add_argument('--in_feat_dropout', help="Please give a value for in_feat_dropout") parser.add_argument('--dropout', help="Please give a value for dropout") parser.add_argument('--graph_norm', help="Please give a value for graph_norm") parser.add_argument('--batch_norm', help="Please give a value for batch_norm") parser.add_argument('--sage_aggregator', help="Please give a value for sage_aggregator") parser.add_argument('--data_mode', help="Please give a value for data_mode") parser.add_argument('--num_pool', help="Please give a value for num_pool") parser.add_argument('--gnn_per_block', help="Please give a value for gnn_per_block") parser.add_argument('--embedding_dim', help="Please give a value for 
embedding_dim") parser.add_argument('--pool_ratio', help="Please give a value for pool_ratio") parser.add_argument('--linkpred', help="Please give a value for linkpred") parser.add_argument('--cat', help="Please give a value for cat") parser.add_argument('--self_loop', help="Please give a value for self_loop") parser.add_argument('--max_time', help="Please give a value for max_time") parser.add_argument('--expid', help='Experiment id.') # pna params parser.add_argument('--aggregators', type=str, help='Aggregators to use.') parser.add_argument('--scalers', type=str, help='Scalers to use.') parser.add_argument('--towers', type=int, help='Towers to use.') parser.add_argument('--divide_input_first', type=str, help='Whether to divide the input in first layers.') parser.add_argument('--divide_input_last', type=str, help='Whether to divide the input in last layer.') parser.add_argument('--gru', type=str, help='Whether to use gru.') parser.add_argument('--edge_dim', type=int, help='Size of edge embeddings.') parser.add_argument('--pretrans_layers', type=int, help='pretrans_layers.') parser.add_argument('--posttrans_layers', type=int, help='posttrans_layers.') args = parser.parse_args() with open(args.config) as f: config = json.load(f) # device if args.gpu_id is not None: config['gpu']['id'] = int(args.gpu_id) config['gpu']['use'] = True device = gpu_setup(config['gpu']['use'], config['gpu']['id']) # dataset, out_dir if args.dataset is not None: DATASET_NAME = args.dataset else: DATASET_NAME = config['dataset'] dataset = MoleculeDataset(DATASET_NAME) if args.out_dir is not None: out_dir = args.out_dir else: out_dir = config['out_dir'] # parameters params = config['params'] if args.seed is not None: params['seed'] = int(args.seed) if args.epochs is not None: params['epochs'] = int(args.epochs) if args.batch_size is not None: params['batch_size'] = int(args.batch_size) if args.init_lr is not None: params['init_lr'] = float(args.init_lr) if args.lr_reduce_factor is not None: 
params['lr_reduce_factor'] = float(args.lr_reduce_factor) if args.lr_schedule_patience is not None: params['lr_schedule_patience'] = int(args.lr_schedule_patience) if args.min_lr is not None: params['min_lr'] = float(args.min_lr) if args.weight_decay is not None: params['weight_decay'] = float(args.weight_decay) if args.print_epoch_interval is not None: params['print_epoch_interval'] = int(args.print_epoch_interval) if args.max_time is not None: params['max_time'] = float(args.max_time) # network parameters net_params = config['net_params'] net_params['device'] = device net_params['gpu_id'] = config['gpu']['id'] net_params['batch_size'] = params['batch_size'] if args.L is not None: net_params['L'] = int(args.L) if args.hidden_dim is not None: net_params['hidden_dim'] = int(args.hidden_dim) if args.out_dim is not None: net_params['out_dim'] = int(args.out_dim) if args.residual is not None: net_params['residual'] = True if args.residual == 'True' else False if args.edge_feat is not None: net_params['edge_feat'] = True if args.edge_feat == 'True' else False if args.readout is not None: net_params['readout'] = args.readout if args.kernel is not None: net_params['kernel'] = int(args.kernel) if args.n_heads is not None: net_params['n_heads'] = int(args.n_heads) if args.gated is not None: net_params['gated'] = True if args.gated == 'True' else False if args.in_feat_dropout is not None: net_params['in_feat_dropout'] = float(args.in_feat_dropout) if args.dropout is not None: net_params['dropout'] = float(args.dropout) if args.graph_norm is not None: net_params['graph_norm'] = True if args.graph_norm == 'True' else False if args.batch_norm is not None: net_params['batch_norm'] = True if args.batch_norm == 'True' else False if args.sage_aggregator is not None: net_params['sage_aggregator'] = args.sage_aggregator if args.data_mode is not None: net_params['data_mode'] = args.data_mode if args.num_pool is not None: net_params['num_pool'] = int(args.num_pool) if 
args.gnn_per_block is not None: net_params['gnn_per_block'] = int(args.gnn_per_block) if args.embedding_dim is not None: net_params['embedding_dim'] = int(args.embedding_dim) if args.pool_ratio is not None: net_params['pool_ratio'] = float(args.pool_ratio) if args.linkpred is not None: net_params['linkpred'] = True if args.linkpred == 'True' else False if args.cat is not None: net_params['cat'] = True if args.cat == 'True' else False if args.self_loop is not None: net_params['self_loop'] = True if args.self_loop == 'True' else False if args.aggregators is not None: net_params['aggregators'] = args.aggregators if args.scalers is not None: net_params['scalers'] = args.scalers if args.towers is not None: net_params['towers'] = args.towers if args.divide_input_first is not None: net_params['divide_input_first'] = True if args.divide_input_first == 'True' else False if args.divide_input_last is not None: net_params['divide_input_last'] = True if args.divide_input_last == 'True' else False if args.gru is not None: net_params['gru'] = True if args.gru == 'True' else False if args.edge_dim is not None: net_params['edge_dim'] = args.edge_dim if args.pretrans_layers is not None: net_params['pretrans_layers'] = args.pretrans_layers if args.posttrans_layers is not None: net_params['posttrans_layers'] = args.posttrans_layers # ZINC net_params['num_atom_type'] = dataset.num_atom_type net_params['num_bond_type'] = dataset.num_bond_type MODEL_NAME = 'PNA' D = torch.cat([torch.sparse.sum(g.adjacency_matrix(transpose=True), dim=-1).to_dense() for g in dataset.train.graph_lists]) net_params['avg_d'] = dict(lin=torch.mean(D), exp=torch.mean(torch.exp(torch.div(1, D)) - 1), log=torch.mean(torch.log(D + 1))) root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str( config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str( config['gpu']['id']) + "_" + 
time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str( config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str( config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file if not os.path.exists(out_dir + 'results'): os.makedirs(out_dir + 'results') if not os.path.exists(out_dir + 'configs'): os.makedirs(out_dir + 'configs') net_params['total_param'] = view_model_param(net_params) train_val_pipeline(dataset, params, net_params, dirs) main() ================================================ FILE: realworld_benchmark/main_superpixels.py ================================================ """ IMPORTING LIBS """ import numpy as np import os import socket import time import random import glob import argparse, json import pickle import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader from tensorboardX import SummaryWriter from tqdm import tqdm class DotDict(dict): def __init__(self, **kwds): self.update(kwds) self.__dict__ = self """ IMPORTING CUSTOM MODULES/METHODS """ from nets.superpixels_graph_classification.pna_net import PNANet from data.superpixels import SuperPixDataset # import dataset from train.train_superpixels_graph_classification import train_epoch, \ evaluate_network # import train functions """ GPU Setup """ def gpu_setup(use_gpu, gpu_id): os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) if torch.cuda.is_available() and use_gpu: print('cuda available with GPU:', torch.cuda.get_device_name(0)) device = torch.device("cuda") else: print('cuda not available') device = torch.device("cpu") return device """ VIEWING MODEL CONFIG AND PARAMS """ def view_model_param(MODEL_NAME, 
net_params): model = PNANet(net_params) total_param = 0 print("MODEL DETAILS:\n") # print(model) for param in model.parameters(): # print(param.data.size()) total_param += np.prod(list(param.data.size())) print('MODEL/Total parameters:', MODEL_NAME, total_param) return total_param """ TRAINING CODE """ def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs): t0 = time.time() per_epoch_time = [] DATASET_NAME = dataset.name trainset, valset, testset = dataset.train, dataset.val, dataset.test root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs device = net_params['device'] # Write the network and optimization hyper-parameters in folder config/ with open(write_config_file + '.txt', 'w') as f: f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n""".format( DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param'])) log_dir = os.path.join(root_log_dir, "RUN_" + str(0)) writer = SummaryWriter(log_dir=log_dir) # setting seeds random.seed(params['seed']) np.random.seed(params['seed']) torch.manual_seed(params['seed']) if device.type == 'cuda': torch.cuda.manual_seed(params['seed']) print("Training Graphs: ", len(trainset)) print("Validation Graphs: ", len(valset)) print("Test Graphs: ", len(testset)) model = PNANet(net_params) model = model.to(device) optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay']) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=params['lr_reduce_factor'], patience=params['lr_schedule_patience'], verbose=True) epoch_train_losses, epoch_val_losses = [], [] epoch_train_accs, epoch_val_accs = [], [] train_loader = DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, collate_fn=dataset.collate) val_loader = DataLoader(valset, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate) test_loader = DataLoader(testset, batch_size=params['batch_size'], 
shuffle=False, collate_fn=dataset.collate) # At any point you can hit Ctrl + C to break out of training early. try: with tqdm(range(params['epochs']), unit='epoch') as t: for epoch in t: t.set_description('Epoch %d' % epoch) start = time.time() epoch_train_loss, epoch_train_acc, optimizer = train_epoch(model, optimizer, device, train_loader, epoch) epoch_val_loss, epoch_val_acc = evaluate_network(model, device, val_loader, epoch) epoch_train_losses.append(epoch_train_loss) epoch_val_losses.append(epoch_val_loss) epoch_train_accs.append(epoch_train_acc) epoch_val_accs.append(epoch_val_acc) writer.add_scalar('train/_loss', epoch_train_loss, epoch) writer.add_scalar('val/_loss', epoch_val_loss, epoch) writer.add_scalar('train/_acc', epoch_train_acc, epoch) writer.add_scalar('val/_acc', epoch_val_acc, epoch) writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch) _, epoch_test_acc = evaluate_network(model, device, test_loader, epoch) t.set_postfix(time=time.time() - start, lr=optimizer.param_groups[0]['lr'], train_loss=epoch_train_loss, val_loss=epoch_val_loss, train_acc=epoch_train_acc, val_acc=epoch_val_acc, test_acc=epoch_test_acc) per_epoch_time.append(time.time() - start) scheduler.step(epoch_val_loss) if optimizer.param_groups[0]['lr'] < params['min_lr']: print("\n!! 
LR EQUAL TO MIN LR SET.") break # Stop training after params['max_time'] hours if time.time() - t0 > params['max_time'] * 3600: print('-' * 89) print("Max_time for training elapsed {:.2f} hours, so stopping".format(params['max_time'])) break except KeyboardInterrupt: print('-' * 89) print('Exiting from training early because of KeyboardInterrupt') _, test_acc = evaluate_network(model, device, test_loader, epoch) _, val_acc = evaluate_network(model, device, val_loader, epoch) _, train_acc = evaluate_network(model, device, train_loader, epoch) print("Test Accuracy: {:.4f}".format(test_acc)) print("Val Accuracy: {:.4f}".format(val_acc)) print("Train Accuracy: {:.4f}".format(train_acc)) print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0)) print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time))) writer.close() """ Write the results in out_dir/results folder """ with open(write_file_name + '.txt', 'w') as f: f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n FINAL RESULTS\nTEST ACCURACY: {:.4f}\nTRAIN ACCURACY: {:.4f}\n\n Total Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n""" \ .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'], np.mean(np.array(test_acc)) * 100, np.mean(np.array(train_acc)) * 100, (time.time() - t0) / 3600, np.mean(per_epoch_time))) # send results to gmail try: from gmail import send subject = 'Result for Dataset: {}, Model: {}'.format(DATASET_NAME, MODEL_NAME) body = """Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n FINAL RESULTS\nTEST ACCURACY: {:.4f}\nTRAIN ACCURACY: {:.4f}\n\n Total Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n""" \ .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'], np.mean(np.array(test_acc)) * 100, np.mean(np.array(train_acc)) * 100, (time.time() - t0) / 3600, np.mean(per_epoch_time)) send(subject, body) except: pass 
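The training loop above pairs `ReduceLROnPlateau` (decay the learning rate by `lr_reduce_factor` after `lr_schedule_patience` epochs without validation improvement) with a hard stop once the LR falls below `min_lr`. A minimal torch-free sketch of that schedule, using a hypothetical `run_schedule` helper that is not part of the repo, just to illustrate the interaction of `factor`, `patience` and `min_lr`:

```python
# Hypothetical stand-alone sketch of the plateau-decay + min_lr early-stop
# logic used in train_val_pipeline (pure Python, no torch dependency).
def run_schedule(val_losses, init_lr=1e-3, factor=0.5, patience=2, min_lr=1e-5):
    lr, best, bad = init_lr, float('inf'), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, bad = loss, 0      # improvement resets the patience counter
        else:
            bad += 1
            if bad > patience:       # plateau exceeded patience: decay the LR
                lr *= factor
                bad = 0
        if lr < min_lr:              # mirrors the `break` in the epoch loop
            return epoch
    return len(val_losses)

# A flat loss curve forces a decay every (patience + 1) epochs until lr < min_lr.
stop_epoch = run_schedule([1.0] * 40)
```

With these defaults, seven decays are needed before `1e-3 * 0.5**k` drops below `1e-5`, so a flat curve stops after roughly `7 * (patience + 1)` epochs, while a steadily improving curve never decays at all.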
def main(): """ USER CONTROLS """ parser = argparse.ArgumentParser() parser.add_argument('--config', help="Please give a config.json file with training/model/data/param details") parser.add_argument('--gpu_id', help="Please give a value for gpu id") parser.add_argument('--model', help="Please give a value for model name") parser.add_argument('--dataset', help="Please give a value for dataset name") parser.add_argument('--out_dir', help="Please give a value for out_dir") parser.add_argument('--seed', help="Please give a value for seed") parser.add_argument('--epochs', help="Please give a value for epochs") parser.add_argument('--batch_size', help="Please give a value for batch_size") parser.add_argument('--init_lr', help="Please give a value for init_lr") parser.add_argument('--lr_reduce_factor', help="Please give a value for lr_reduce_factor") parser.add_argument('--lr_schedule_patience', help="Please give a value for lr_schedule_patience") parser.add_argument('--min_lr', help="Please give a value for min_lr") parser.add_argument('--weight_decay', help="Please give a value for weight_decay") parser.add_argument('--print_epoch_interval', help="Please give a value for print_epoch_interval") parser.add_argument('--L', help="Please give a value for L") parser.add_argument('--hidden_dim', help="Please give a value for hidden_dim") parser.add_argument('--out_dim', help="Please give a value for out_dim") parser.add_argument('--residual', help="Please give a value for residual") parser.add_argument('--edge_feat', help="Please give a value for edge_feat") parser.add_argument('--readout', help="Please give a value for readout") parser.add_argument('--kernel', help="Please give a value for kernel") parser.add_argument('--n_heads', help="Please give a value for n_heads") parser.add_argument('--gated', help="Please give a value for gated") parser.add_argument('--in_feat_dropout', help="Please give a value for in_feat_dropout") parser.add_argument('--dropout', help="Please give 
a value for dropout") parser.add_argument('--graph_norm', help="Please give a value for graph_norm") parser.add_argument('--batch_norm', help="Please give a value for batch_norm") parser.add_argument('--sage_aggregator', help="Please give a value for sage_aggregator") parser.add_argument('--data_mode', help="Please give a value for data_mode") parser.add_argument('--num_pool', help="Please give a value for num_pool") parser.add_argument('--gnn_per_block', help="Please give a value for gnn_per_block") parser.add_argument('--embedding_dim', help="Please give a value for embedding_dim") parser.add_argument('--pool_ratio', help="Please give a value for pool_ratio") parser.add_argument('--linkpred', help="Please give a value for linkpred") parser.add_argument('--cat', help="Please give a value for cat") parser.add_argument('--self_loop', help="Please give a value for self_loop") parser.add_argument('--max_time', help="Please give a value for max_time") parser.add_argument('--expid', help='Experiment id.') # pna params parser.add_argument('--aggregators', type=str, help='Aggregators to use.') parser.add_argument('--scalers', type=str, help='Scalers to use.') parser.add_argument('--towers', type=int, help='Towers to use.') parser.add_argument('--divide_input_first', type=str, help='Whether to divide the input in first layers.') parser.add_argument('--divide_input_last', type=str, help='Whether to divide the input in last layer.') parser.add_argument('--gru', type=str, help='Whether to use gru.') parser.add_argument('--edge_dim', type=int, help='Size of edge embeddings.') parser.add_argument('--pretrans_layers', type=int, help='pretrans_layers.') parser.add_argument('--posttrans_layers', type=int, help='posttrans_layers.') args = parser.parse_args() with open(args.config) as f: config = json.load(f) # device if args.gpu_id is not None: config['gpu']['id'] = int(args.gpu_id) config['gpu']['use'] = True device = gpu_setup(config['gpu']['use'], config['gpu']['id']) # model, 
dataset, out_dir if args.model is not None: MODEL_NAME = args.model else: MODEL_NAME = config['model'] if args.dataset is not None: DATASET_NAME = args.dataset else: DATASET_NAME = config['dataset'] dataset = SuperPixDataset(DATASET_NAME) if args.out_dir is not None: out_dir = args.out_dir else: out_dir = config['out_dir'] # parameters params = config['params'] if args.seed is not None: params['seed'] = int(args.seed) if args.epochs is not None: params['epochs'] = int(args.epochs) if args.batch_size is not None: params['batch_size'] = int(args.batch_size) if args.init_lr is not None: params['init_lr'] = float(args.init_lr) if args.lr_reduce_factor is not None: params['lr_reduce_factor'] = float(args.lr_reduce_factor) if args.lr_schedule_patience is not None: params['lr_schedule_patience'] = int(args.lr_schedule_patience) if args.min_lr is not None: params['min_lr'] = float(args.min_lr) if args.weight_decay is not None: params['weight_decay'] = float(args.weight_decay) if args.print_epoch_interval is not None: params['print_epoch_interval'] = int(args.print_epoch_interval) if args.max_time is not None: params['max_time'] = float(args.max_time) # network parameters net_params = config['net_params'] net_params['device'] = device net_params['gpu_id'] = config['gpu']['id'] net_params['batch_size'] = params['batch_size'] if args.L is not None: net_params['L'] = int(args.L) if args.hidden_dim is not None: net_params['hidden_dim'] = int(args.hidden_dim) if args.out_dim is not None: net_params['out_dim'] = int(args.out_dim) if args.residual is not None: net_params['residual'] = True if args.residual == 'True' else False if args.edge_feat is not None: net_params['edge_feat'] = True if args.edge_feat == 'True' else False if args.readout is not None: net_params['readout'] = args.readout if args.kernel is not None: net_params['kernel'] = int(args.kernel) if args.n_heads is not None: net_params['n_heads'] = int(args.n_heads) if args.gated is not None: net_params['gated'] = True 
if args.gated == 'True' else False if args.in_feat_dropout is not None: net_params['in_feat_dropout'] = float(args.in_feat_dropout) if args.dropout is not None: net_params['dropout'] = float(args.dropout) if args.graph_norm is not None: net_params['graph_norm'] = True if args.graph_norm == 'True' else False if args.batch_norm is not None: net_params['batch_norm'] = True if args.batch_norm == 'True' else False if args.sage_aggregator is not None: net_params['sage_aggregator'] = args.sage_aggregator if args.data_mode is not None: net_params['data_mode'] = args.data_mode if args.num_pool is not None: net_params['num_pool'] = int(args.num_pool) if args.gnn_per_block is not None: net_params['gnn_per_block'] = int(args.gnn_per_block) if args.embedding_dim is not None: net_params['embedding_dim'] = int(args.embedding_dim) if args.pool_ratio is not None: net_params['pool_ratio'] = float(args.pool_ratio) if args.linkpred is not None: net_params['linkpred'] = True if args.linkpred == 'True' else False if args.cat is not None: net_params['cat'] = True if args.cat == 'True' else False if args.self_loop is not None: net_params['self_loop'] = True if args.self_loop == 'True' else False if args.aggregators is not None: net_params['aggregators'] = args.aggregators if args.scalers is not None: net_params['scalers'] = args.scalers if args.towers is not None: net_params['towers'] = args.towers if args.divide_input_first is not None: net_params['divide_input_first'] = True if args.divide_input_first == 'True' else False if args.divide_input_last is not None: net_params['divide_input_last'] = True if args.divide_input_last == 'True' else False if args.gru is not None: net_params['gru'] = True if args.gru == 'True' else False if args.edge_dim is not None: net_params['edge_dim'] = args.edge_dim if args.pretrans_layers is not None: net_params['pretrans_layers'] = args.pretrans_layers if args.posttrans_layers is not None: net_params['posttrans_layers'] = args.posttrans_layers # 
Superpixels net_params['in_dim'] = dataset.train[0][0].ndata['feat'][0].size(0) net_params['in_dim_edge'] = dataset.train[0][0].edata['feat'][0].size(0) num_classes = len(np.unique(np.array(dataset.train[:][1]))) net_params['n_classes'] = num_classes if MODEL_NAME == 'PNA': D = torch.cat([torch.sparse.sum(g.adjacency_matrix(transpose=True), dim=-1).to_dense() for g in dataset.train.graph_lists]) net_params['avg_d'] = dict(lin=torch.mean(D), exp=torch.mean(torch.exp(torch.div(1, D)) - 1), log=torch.mean(torch.log(D + 1))) root_log_dir = out_dir + 'logs/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str( config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') root_ckpt_dir = out_dir + 'checkpoints/' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str( config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') write_file_name = out_dir + 'results/result_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str( config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') write_config_file = out_dir + 'configs/config_' + MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str( config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y') dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file if not os.path.exists(out_dir + 'results'): os.makedirs(out_dir + 'results') if not os.path.exists(out_dir + 'configs'): os.makedirs(out_dir + 'configs') net_params['total_param'] = view_model_param(MODEL_NAME, net_params) train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs) main() ================================================ FILE: realworld_benchmark/nets/HIV_graph_classification/pna_net.py ================================================ import torch.nn as nn import dgl from models.dgl.pna_layer import PNASimpleLayer from nets.mlp_readout_layer import MLPReadout import torch from ogb.graphproppred.mol_encoder import AtomEncoder class PNANet(nn.Module): def __init__(self, net_params): super().__init__() hidden_dim = 
net_params['hidden_dim'] out_dim = net_params['out_dim'] in_feat_dropout = net_params['in_feat_dropout'] dropout = net_params['dropout'] n_layers = net_params['L'] self.readout = net_params['readout'] self.batch_norm = net_params['batch_norm'] self.aggregators = net_params['aggregators'] self.scalers = net_params['scalers'] self.avg_d = net_params['avg_d'] self.residual = net_params['residual'] posttrans_layers = net_params['posttrans_layers'] device = net_params['device'] self.device = device self.in_feat_dropout = nn.Dropout(in_feat_dropout) self.embedding_h = AtomEncoder(emb_dim=hidden_dim) self.layers = nn.ModuleList( [PNASimpleLayer(in_dim=hidden_dim, out_dim=hidden_dim, dropout=dropout, batch_norm=self.batch_norm, residual=self.residual, aggregators=self.aggregators, scalers=self.scalers, avg_d=self.avg_d, posttrans_layers=posttrans_layers) for _ in range(n_layers - 1)]) self.layers.append(PNASimpleLayer(in_dim=hidden_dim, out_dim=out_dim, dropout=dropout, batch_norm=self.batch_norm, residual=self.residual, aggregators=self.aggregators, scalers=self.scalers, avg_d=self.avg_d, posttrans_layers=posttrans_layers)) self.MLP_layer = MLPReadout(out_dim, 1) # 1 out dim since binary classification def forward(self, g, h): h = self.embedding_h(h) h = self.in_feat_dropout(h) for i, conv in enumerate(self.layers): h = conv(g, h) g.ndata['h'] = h if self.readout == "sum": hg = dgl.sum_nodes(g, 'h') elif self.readout == "max": hg = dgl.max_nodes(g, 'h') elif self.readout == "mean": hg = dgl.mean_nodes(g, 'h') else: hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes return self.MLP_layer(hg) def loss(self, scores, labels): loss = torch.nn.BCEWithLogitsLoss()(scores, labels.type(torch.FloatTensor).to(self.device).unsqueeze(-1)) return loss ================================================ FILE: realworld_benchmark/nets/gru.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F class GRU(nn.Module): """ Wrapper 
class for the GRU used by the GNN framework, nn.GRU is used for the Gated Recurrent Unit itself """ def __init__(self, input_size, hidden_size, device): super(GRU, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size).to(device) def forward(self, x, y): """ :param x: shape: (B, N, Din) where Din <= input_size (difference is padded) :param y: shape: (B, N, Dh) where Dh <= hidden_size (difference is padded) :return: shape: (B, N, Dh) """ assert (x.shape[-1] <= self.input_size and y.shape[-1] <= self.hidden_size) x = x.unsqueeze(0) y = y.unsqueeze(0) x = self.gru(x, y)[1] x = x.squeeze(0)  # squeeze only the layer dim, so a size-1 node dim survives return x ================================================ FILE: realworld_benchmark/nets/mlp_readout_layer.py ================================================ # MIT License # Copyright (c) 2020 Vijay Prakash Dwivedi, Chaitanya K. Joshi, Thomas Laurent, Yoshua Bengio, Xavier Bresson import torch import torch.nn as nn import torch.nn.functional as F """ MLP Layer used after graph vector representation """ class MLPReadout(nn.Module): def __init__(self, input_dim, output_dim, L=2): # L=nb_hidden_layers super().__init__() list_FC_layers = [nn.Linear(input_dim // 2 ** l, input_dim // 2 ** (l + 1), bias=True) for l in range(L)] list_FC_layers.append(nn.Linear(input_dim // 2 ** L, output_dim, bias=True)) self.FC_layers = nn.ModuleList(list_FC_layers) self.L = L def forward(self, x): y = x for l in range(self.L): y = self.FC_layers[l](y) y = F.relu(y) y = self.FC_layers[self.L](y) return y ================================================ FILE: realworld_benchmark/nets/molecules_graph_regression/pna_net.py ================================================ import torch.nn as nn import dgl from nets.gru import GRU from models.dgl.pna_layer import PNALayer from nets.mlp_readout_layer import MLPReadout """ PNA: Principal Neighbourhood Aggregation Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, 
Petar Velickovic https://arxiv.org/abs/2004.05718 Architecture follows that in https://github.com/graphdeeplearning/benchmarking-gnns """ class PNANet(nn.Module): def __init__(self, net_params): super().__init__() num_atom_type = net_params['num_atom_type'] num_bond_type = net_params['num_bond_type'] hidden_dim = net_params['hidden_dim'] out_dim = net_params['out_dim'] in_feat_dropout = net_params['in_feat_dropout'] dropout = net_params['dropout'] n_layers = net_params['L'] self.readout = net_params['readout'] self.graph_norm = net_params['graph_norm'] self.batch_norm = net_params['batch_norm'] self.residual = net_params['residual'] self.aggregators = net_params['aggregators'] self.scalers = net_params['scalers'] self.avg_d = net_params['avg_d'] self.towers = net_params['towers'] self.divide_input_first = net_params['divide_input_first'] self.divide_input_last = net_params['divide_input_last'] self.edge_feat = net_params['edge_feat'] edge_dim = net_params['edge_dim'] pretrans_layers = net_params['pretrans_layers'] posttrans_layers = net_params['posttrans_layers'] self.gru_enable = net_params['gru'] device = net_params['device'] self.in_feat_dropout = nn.Dropout(in_feat_dropout) self.embedding_h = nn.Embedding(num_atom_type, hidden_dim) if self.edge_feat: self.embedding_e = nn.Embedding(num_bond_type, edge_dim) self.layers = nn.ModuleList([PNALayer(in_dim=hidden_dim, out_dim=hidden_dim, dropout=dropout, graph_norm=self.graph_norm, batch_norm=self.batch_norm, residual=self.residual, aggregators=self.aggregators, scalers=self.scalers, avg_d=self.avg_d, towers=self.towers, edge_features=self.edge_feat, edge_dim=edge_dim, divide_input=self.divide_input_first, pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers) for _ in range(n_layers - 1)]) self.layers.append(PNALayer(in_dim=hidden_dim, out_dim=out_dim, dropout=dropout, graph_norm=self.graph_norm, batch_norm=self.batch_norm, residual=self.residual, aggregators=self.aggregators, scalers=self.scalers, 
avg_d=self.avg_d, towers=self.towers, divide_input=self.divide_input_last, edge_features=self.edge_feat, edge_dim=edge_dim, pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers)) if self.gru_enable: self.gru = GRU(hidden_dim, hidden_dim, device) self.MLP_layer = MLPReadout(out_dim, 1) # 1 out dim since regression problem def forward(self, g, h, e, snorm_n, snorm_e): h = self.embedding_h(h) h = self.in_feat_dropout(h) if self.edge_feat: e = self.embedding_e(e) for i, conv in enumerate(self.layers): h_t = conv(g, h, e, snorm_n) if self.gru_enable and i != len(self.layers) - 1: h_t = self.gru(h, h_t) h = h_t g.ndata['h'] = h if self.readout == "sum": hg = dgl.sum_nodes(g, 'h') elif self.readout == "max": hg = dgl.max_nodes(g, 'h') elif self.readout == "mean": hg = dgl.mean_nodes(g, 'h') else: hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes return self.MLP_layer(hg) def loss(self, scores, targets): loss = nn.L1Loss()(scores, targets) return loss ================================================ FILE: realworld_benchmark/nets/superpixels_graph_classification/pna_net.py ================================================ import torch.nn as nn import dgl from nets.gru import GRU from models.dgl.pna_layer import PNALayer from nets.mlp_readout_layer import MLPReadout """ PNA: Principal Neighbourhood Aggregation Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic https://arxiv.org/abs/2004.05718 Architecture follows that in https://github.com/graphdeeplearning/benchmarking-gnns """ class PNANet(nn.Module): def __init__(self, net_params): super().__init__() in_dim = net_params['in_dim'] in_dim_edge = net_params['in_dim_edge'] hidden_dim = net_params['hidden_dim'] out_dim = net_params['out_dim'] n_classes = net_params['n_classes'] in_feat_dropout = net_params['in_feat_dropout'] dropout = net_params['dropout'] n_layers = net_params['L'] self.readout = net_params['readout'] self.graph_norm = net_params['graph_norm'] 
        self.batch_norm = net_params['batch_norm']
        self.residual = net_params['residual']
        self.aggregators = net_params['aggregators']
        self.scalers = net_params['scalers']
        self.avg_d = net_params['avg_d']
        self.towers = net_params['towers']
        self.divide_input_first = net_params['divide_input_first']
        self.divide_input_last = net_params['divide_input_last']
        self.edge_feat = net_params['edge_feat']
        edge_dim = net_params['edge_dim']
        pretrans_layers = net_params['pretrans_layers']
        posttrans_layers = net_params['posttrans_layers']
        self.gru_enable = net_params['gru']
        device = net_params['device']

        self.embedding_h = nn.Linear(in_dim, hidden_dim)
        if self.edge_feat:
            self.embedding_e = nn.Linear(in_dim_edge, edge_dim)

        self.layers = nn.ModuleList([PNALayer(in_dim=hidden_dim, out_dim=hidden_dim, dropout=dropout,
                                              graph_norm=self.graph_norm, batch_norm=self.batch_norm,
                                              residual=self.residual, aggregators=self.aggregators,
                                              scalers=self.scalers, avg_d=self.avg_d, towers=self.towers,
                                              edge_features=self.edge_feat, edge_dim=edge_dim,
                                              divide_input=self.divide_input_first,
                                              pretrans_layers=pretrans_layers,
                                              posttrans_layers=posttrans_layers)
                                     for _ in range(n_layers - 1)])
        self.layers.append(PNALayer(in_dim=hidden_dim, out_dim=out_dim, dropout=dropout,
                                    graph_norm=self.graph_norm, batch_norm=self.batch_norm,
                                    residual=self.residual, aggregators=self.aggregators,
                                    scalers=self.scalers, avg_d=self.avg_d, towers=self.towers,
                                    divide_input=self.divide_input_last,
                                    edge_features=self.edge_feat, edge_dim=edge_dim,
                                    pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers))

        if self.gru_enable:
            self.gru = GRU(hidden_dim, hidden_dim, device)

        self.MLP_layer = MLPReadout(out_dim, n_classes)

    def forward(self, g, h, e, snorm_n, snorm_e):
        h = self.embedding_h(h)
        if self.edge_feat:
            e = self.embedding_e(e)

        for i, conv in enumerate(self.layers):
            h_t = conv(g, h, e, snorm_n)
            if self.gru_enable and i != len(self.layers) - 1:
                h_t = self.gru(h, h_t)
            h = h_t

        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)

    def loss(self, pred, label):
        criterion = nn.CrossEntropyLoss()
        loss = criterion(pred, label)
        return loss



================================================
FILE: realworld_benchmark/train/metrics.py
================================================
# MIT License
# Copyright (c) 2020 Vijay Prakash Dwivedi, Chaitanya K. Joshi, Thomas Laurent, Yoshua Bengio, Xavier Bresson

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
import numpy as np


def MAE(scores, targets):
    MAE = F.l1_loss(scores, targets)
    return MAE


def accuracy_TU(scores, targets):
    scores = scores.detach().argmax(dim=1)
    acc = (scores == targets).float().sum().item()
    return acc


def accuracy_MNIST_CIFAR(scores, targets):
    scores = scores.detach().argmax(dim=1)
    acc = (scores == targets).float().sum().item()
    return acc


def accuracy_CITATION_GRAPH(scores, targets):
    scores = scores.detach().argmax(dim=1)
    acc = (scores == targets).float().sum().item()
    acc = acc / len(targets)
    return acc


def accuracy_SBM(scores, targets):
    S = targets.cpu().numpy()
    C = np.argmax(torch.nn.Softmax(dim=0)(scores).cpu().detach().numpy(), axis=1)
    CM = confusion_matrix(S, C).astype(np.float32)
    nb_classes = CM.shape[0]
    targets = targets.cpu().detach().numpy()
    nb_non_empty_classes = 0
    pr_classes = np.zeros(nb_classes)
    for r in range(nb_classes):
        cluster = np.where(targets == r)[0]
        if cluster.shape[0] != 0:
            pr_classes[r] = CM[r, r] / float(cluster.shape[0])
            if CM[r, r] > 0:
                nb_non_empty_classes += 1
        else:
            pr_classes[r] = 0.0
    acc = 100. * np.sum(pr_classes) / float(nb_non_empty_classes)
    return acc


def binary_f1_score(scores, targets):
    """Computes the F1 score using scikit-learn for binary class labels.

    Returns the F1 score for the positive class, i.e. labelled '1'.
    """
    y_true = targets.cpu().numpy()
    y_pred = scores.argmax(dim=1).cpu().numpy()
    return f1_score(y_true, y_pred, average='binary')


def accuracy_VOC(scores, targets):
    scores = scores.detach().argmax(dim=1).cpu()
    targets = targets.cpu().detach().numpy()
    acc = f1_score(scores, targets, average='weighted')
    return acc



================================================
FILE: realworld_benchmark/train/train_HIV_graph_classification.py
================================================
import torch
from ogb.graphproppred import Evaluator


def train_epoch_sparse(model, optimizer, device, data_loader, epoch):
    model.train()
    epoch_loss = 0
    list_scores = []
    list_labels = []
    for iter, (batch_graphs, batch_labels) in enumerate(data_loader):
        batch_x = batch_graphs.ndata['feat'].to(device)  # num x feat
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        batch_scores = model.forward(batch_graphs, batch_x)
        loss = model.loss(batch_scores, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        list_scores.append(batch_scores.detach())
        list_labels.append(batch_labels.detach().unsqueeze(-1))
    epoch_loss /= (iter + 1)
    evaluator = Evaluator(name='ogbg-molhiv')
    epoch_train_ROC = evaluator.eval({'y_pred': torch.cat(list_scores),
                                      'y_true': torch.cat(list_labels)})['rocauc']
    return epoch_loss, epoch_train_ROC, optimizer


def evaluate_network_sparse(model, device, data_loader, epoch):
    model.eval()
    epoch_test_loss = 0
    epoch_test_ROC = 0
    with torch.no_grad():
        list_scores = []
        list_labels = []
        for iter, (batch_graphs, batch_labels) in enumerate(data_loader):
            batch_x = batch_graphs.ndata['feat'].to(device)
            batch_labels = batch_labels.to(device)
            batch_scores = model.forward(batch_graphs, batch_x)
            loss = model.loss(batch_scores, batch_labels)
            epoch_test_loss += loss.detach().item()
            list_scores.append(batch_scores.detach())
            list_labels.append(batch_labels.detach().unsqueeze(-1))
        epoch_test_loss /= (iter + 1)
        evaluator = Evaluator(name='ogbg-molhiv')
        epoch_test_ROC = evaluator.eval({'y_pred': torch.cat(list_scores),
                                         'y_true': torch.cat(list_labels)})['rocauc']

    return epoch_test_loss, epoch_test_ROC



================================================
FILE: realworld_benchmark/train/train_molecules_graph_regression.py
================================================
# MIT License
# Copyright (c) 2020 Vijay Prakash Dwivedi, Chaitanya K. Joshi, Thomas Laurent, Yoshua Bengio, Xavier Bresson

"""
    Utility functions for training one epoch and evaluating one epoch
"""
import torch
import torch.nn as nn
import math

from .metrics import MAE


def train_epoch(model, optimizer, device, data_loader, epoch):
    model.train()
    epoch_loss = 0
    epoch_train_mae = 0
    nb_data = 0
    gpu_mem = 0
    for iter, (batch_graphs, batch_targets, batch_snorm_n, batch_snorm_e) in enumerate(data_loader):
        batch_x = batch_graphs.ndata['feat'].to(device)  # num x feat
        batch_e = batch_graphs.edata['feat'].to(device)
        batch_snorm_e = batch_snorm_e.to(device)
        batch_targets = batch_targets.to(device)
        batch_snorm_n = batch_snorm_n.to(device)  # num x 1
        optimizer.zero_grad()
        batch_scores = model.forward(batch_graphs, batch_x, batch_e, batch_snorm_n, batch_snorm_e)
        loss = model.loss(batch_scores, batch_targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        epoch_train_mae += MAE(batch_scores, batch_targets)
        nb_data += batch_targets.size(0)
    epoch_loss /= (iter + 1)
    epoch_train_mae /= (iter + 1)
    return epoch_loss, epoch_train_mae, optimizer


def evaluate_network(model, device, data_loader, epoch):
    model.eval()
    epoch_test_loss = 0
    epoch_test_mae = 0
    nb_data = 0
    with torch.no_grad():
        for iter, (batch_graphs, batch_targets, batch_snorm_n, batch_snorm_e) in enumerate(data_loader):
            batch_x = batch_graphs.ndata['feat'].to(device)
            batch_e = batch_graphs.edata['feat'].to(device)
            batch_snorm_e = batch_snorm_e.to(device)
            batch_targets = batch_targets.to(device)
            batch_snorm_n = batch_snorm_n.to(device)
            batch_scores = model.forward(batch_graphs, batch_x, batch_e, batch_snorm_n, batch_snorm_e)
            loss = model.loss(batch_scores, batch_targets)
            epoch_test_loss += loss.detach().item()
            epoch_test_mae += MAE(batch_scores, batch_targets)
            nb_data += batch_targets.size(0)
        epoch_test_loss /= (iter + 1)
        epoch_test_mae /= (iter + 1)

    return epoch_test_loss, epoch_test_mae



================================================
FILE: realworld_benchmark/train/train_superpixels_graph_classification.py
================================================
# MIT License
# Copyright (c) 2020 Vijay Prakash Dwivedi, Chaitanya K. Joshi, Thomas Laurent, Yoshua Bengio, Xavier Bresson

"""
    Utility functions for training one epoch and evaluating one epoch
"""
import torch
import torch.nn as nn
import math

from .metrics import accuracy_MNIST_CIFAR as accuracy


def train_epoch(model, optimizer, device, data_loader, epoch):
    model.train()
    epoch_loss = 0
    epoch_train_acc = 0
    nb_data = 0
    gpu_mem = 0
    for iter, (batch_graphs, batch_labels, batch_snorm_n, batch_snorm_e) in enumerate(data_loader):
        batch_x = batch_graphs.ndata['feat'].to(device)  # num x feat
        batch_e = batch_graphs.edata['feat'].to(device)
        batch_snorm_e = batch_snorm_e.to(device)
        batch_labels = batch_labels.to(device)
        batch_snorm_n = batch_snorm_n.to(device)  # num x 1
        optimizer.zero_grad()
        batch_scores = model.forward(batch_graphs, batch_x, batch_e, batch_snorm_n, batch_snorm_e)
        loss = model.loss(batch_scores, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        epoch_train_acc += accuracy(batch_scores, batch_labels)
        nb_data += batch_labels.size(0)
    epoch_loss /= (iter + 1)
    epoch_train_acc /= nb_data
    return epoch_loss, epoch_train_acc, optimizer


def evaluate_network(model, device, data_loader, epoch):
    model.eval()
    epoch_test_loss = 0
    epoch_test_acc = 0
    nb_data = 0
    with torch.no_grad():
        for iter, (batch_graphs, batch_labels, batch_snorm_n, batch_snorm_e) in enumerate(data_loader):
            batch_x = batch_graphs.ndata['feat'].to(device)
            batch_e = batch_graphs.edata['feat'].to(device)
            batch_snorm_e = batch_snorm_e.to(device)
            batch_labels = batch_labels.to(device)
            batch_snorm_n = batch_snorm_n.to(device)
            batch_scores = model.forward(batch_graphs, batch_x, batch_e, batch_snorm_n, batch_snorm_e)
            loss = model.loss(batch_scores, batch_labels)
            epoch_test_loss += loss.detach().item()
            epoch_test_acc += accuracy(batch_scores, batch_labels)
            nb_data += batch_labels.size(0)
        epoch_test_loss /= (iter + 1)
        epoch_test_acc /= nb_data

    return epoch_test_loss, epoch_test_acc
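A detail worth noting in the epoch loops above: the loss is averaged over *batches* (`/= (iter + 1)`), while the classification accuracy is a running *count* of correct predictions (as returned by `accuracy_MNIST_CIFAR`) that is divided by the total *number of samples* (`/= nb_data`). A minimal plain-Python sketch with hypothetical per-batch statistics (not taken from the repository) illustrates the two conventions:

```python
# Hypothetical per-batch statistics: (mean loss, n_correct, batch_size).
# These numbers are invented for illustration only.
batches = [
    (0.9, 20, 32),
    (0.7, 25, 32),
    (0.5, 14, 16),
]

epoch_loss = 0.0
epoch_acc = 0.0
nb_data = 0
for it, (loss, n_correct, batch_size) in enumerate(batches):
    epoch_loss += loss        # sum of per-batch mean losses
    epoch_acc += n_correct    # running count of correct predictions
    nb_data += batch_size

epoch_loss /= (it + 1)  # mean over batches, as in the loops above
epoch_acc /= nb_data    # fraction of correctly classified samples

print(round(epoch_loss, 4), round(epoch_acc, 4))  # 0.7 0.7375
```

Averaging the loss per batch matches what the optimizer sees each step, whereas dividing the correct-count by `nb_data` gives an exact sample-level accuracy even when the last batch is smaller than the rest.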