Repository: HKUDS/AnyGraph Branch: main Commit: c2c5bfe103c4 Files: 15 Total size: 104.2 KB Directory structure: gitextract_t0oakqy3/ ├── .gitignore ├── History/ │ ├── pretrain_link1.his │ └── pretrain_link2.his ├── Models/ │ └── README.md ├── README.md ├── Utils/ │ └── TimeLogger.py ├── data_handler.py ├── main.py ├── model.py ├── node_classification/ │ ├── Utils/ │ │ └── TimeLogger.py │ ├── data_handler.py │ ├── main.py │ ├── model.py │ └── params.py └── params.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ extract_code_structure.py .DS_Store ================================================ FILE: Models/README.md ================================================ Download the pre-trained AnyGraph models from this link. ================================================ FILE: README.md ================================================

# AnyGraph: Graph Foundation Model in the Wild

Lianghao Xia and Chao Huang

**Introducing AnyGraph, a graph foundation model designed for zero-shot predictions across domains.**

**Objectives of AnyGraph:**
- **Structure Heterogeneity**: Addressing distribution shift in graph structural information.
- **Feature Heterogeneity**: Handling diverse feature representation spaces across graph datasets.
- **Fast Adaptation**: Efficiently adapting the model to new graph domains.
- **Scaling Law Emergence**: Performance scales with the amount of data and model parameters.
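Fast adaptation hinges on routing each incoming graph to a well-matched expert model. As a purely illustrative, hypothetical sketch (not the repository's `assign_experts` implementation), an expert can be scored by how well its node embeddings rank a small sample of the new graph's edges:

```python
# Illustrative expert-routing sketch (hypothetical, not AnyGraph's algorithm):
# score each expert's node embeddings on sampled edges, pick the best fit.
import numpy as np

rng = np.random.default_rng(0)

def edge_score(embeds, edges):
    """Mean dot-product score over the given (src, dst) edge sample."""
    src, dst = edges
    return float((embeds[src] * embeds[dst]).sum(axis=1).mean())

def route(expert_embeds_list, edges):
    """Route the graph to the expert whose embeddings best fit its edges."""
    scores = [edge_score(e, edges) for e in expert_embeds_list]
    return int(np.argmax(scores))

# Toy setup: expert 1's embeddings are aligned with the sampled edges.
n, d = 20, 8
src, dst = np.arange(10), np.arange(10, 20)
bad = rng.normal(size=(n, d))
good = rng.normal(size=(n, d))
good[src] = good[dst]  # endpoints of true edges share embeddings
assert route([bad, good], (src, dst)) == 1
```

Scoring only a small edge sample (rather than fine-tuning every expert) is what keeps routing of this kind lightweight.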
**Key Features of AnyGraph:**
- **Graph Mixture-of-Experts (MoE)**: Effectively addresses cross-domain heterogeneity using an array of expert models.
- **Lightweight Graph Expert Routing Mechanism**: Enables swift adaptation to new datasets and domains.
- **Adaptive and Efficient Graph Experts**: Custom-designed to handle graphs with a wide range of structural patterns and feature spaces.
- **Extensively Trained and Tested**: Exhibits strong generalizability over 38 diverse graph datasets, showcasing scaling laws and emergent capabilities.

## News
- [x] 2024/09/18 Fixed minor bugs in loading and saving paths, class names, etc.
- [x] 2024/09/11 Datasets are available now!
- [x] 2024/09/11 Pre-trained models are available on Hugging Face now!
- [x] 2024/08/20 Paper is released.
- [x] 2024/08/20 Code is released.

## Environment Setup

Download the data files at this link, and fill in your own data-storage directories in the `get_data_files(self)` method of class `DataHandler` in `data_handler.py`. Download the pre-trained AnyGraph models from Hugging Face or OneDrive, and put them into `Models/`.

**Packages**: Our experiments were conducted with the following package versions:
* python==3.10.13
* torch==1.13.0
* numpy==1.23.4
* scipy==1.9.3

**Device Requirements**: Training and testing AnyGraph requires only one GPU with 24 GB of memory (e.g. a 3090 or 4090). Larger input graphs may require devices with more memory.

## Code Structure

Here is a brief overview of the code structure. The explanations for each directory are enclosed in `##...##` comments.
```
./
├── .gitignore
├── README.md
├── data_handler.py        ## load and process data
├── model.py               ## implementation of the model
├── params.py              ## hyperparameters
├── main.py                ## main file for pre-training and link prediction
├── History/               ## training and testing logs
│   ├── pretrain_link1.his
│   └── pretrain_link2.his
├── Models/                ## pre-trained models
│   └── README.md
├── Utils/                 ## utility functions
│   └── TimeLogger.py
├── imgs/                  ## images used in the readme
│   ├── ablation.png
│   ├── article cover.png
│   ├── datasets.png
│   ├── framework_final.jpeg
│   ├── framework_final.pdf
│   ├── framework_final.png
│   ├── overall_performance2.png
│   ├── routing.png
│   ├── scaling_law.png
│   ├── training_time.png
│   ├── tuning_steps.png
│   └── overall_performance1.png
└── node_classification/   ## test code for node classification
    ├── data_handler.py
    ├── model.py
    ├── params.py
    ├── main.py
    └── Utils/
        └── TimeLogger.py
```

## Usage

To reproduce the test performance reported in the paper, run the following commands:

```
# Test on Link2 and Link1 data, respectively
python main.py --load pretrain_link1 --epoch 0 --dataset link2
python main.py --load pretrain_link2 --epoch 0 --dataset link1

# Test on the Ecommerce datasets in the Link2 and Link1 groups, respectively.
# Testing on the Academic and Others datasets is conducted similarly.
python main.py --load pretrain_link1 --epoch 0 --dataset ecommerce_in_link2
python main.py --load pretrain_link2 --epoch 0 --dataset ecommerce_in_link1

# Test the performance on the node classification datasets
cd ./node_classification
python main.py --load pretrain_link2 --epoch 0 --dataset node
```

To re-train the two models yourself, run:

```
python main.py --dataset link2+link1 --save pretrain_link2
python main.py --dataset link1+link2 --save pretrain_link1
```

## Datasets

The statistics for the experimental datasets are presented in the table above.
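The `--dataset` flag in the commands above composes named dataset groups. Based on the parsing logic in `main.py` (simplified here, with group contents abbreviated to toy lists), `a+b` trains on group `a` and tests on group `b`, while `a_in_b` restricts evaluation to the intersection of two groups:

```python
# Simplified sketch of the `--dataset` parsing in main.py.
def parse_dataset_setting(setting, groups):
    if setting in groups:                    # e.g. "link1": train and test on the group
        return groups[setting], groups[setting]
    if '+' in setting:                       # "trn_group+tst_group": zero-shot split
        trn, tst = setting.split('+', 1)
        return groups[trn], groups[tst]
    if '_in_' in setting:                    # e.g. "ecommerce_in_link2": intersection
        a, b = setting.split('_in_', 1)
        inter = [d for d in groups[a] if d in groups[b]]
        return inter, inter
    return [setting], [setting]              # a single dataset name

# Toy groups for illustration (the real lists live in main.py).
groups = {
    'link1': ['yelp2018', 'pubmed'],
    'link2': ['gowalla', 'cora'],
    'ecommerce': ['yelp2018', 'gowalla'],
}
assert parse_dataset_setting('link1+link2', groups) == (['yelp2018', 'pubmed'], ['gowalla', 'cora'])
assert parse_dataset_setting('ecommerce_in_link2', groups) == (['gowalla'], ['gowalla'])
```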
We categorize them into the distinct groups below. Note that Link1 and Link2 include datasets from different sources, and the datasets do not share the same feature spaces. This separation ensures a robust evaluation of true zero-shot performance in graph prediction tasks.

| Group | Included Datasets |
| ----- | ----- |
| Link1 | Products-tech, Yelp2018, Yelp-textfeat, Products-home, Steam-text, Amazon-text, Amazon-book, Citation-2019, Citation-20Century, Pubmed-link, Citeseer, OGB-PPA, P2P-Gnutella06, Soc-Epinions1, Email-Enron |
| Link2 | Photo, Goodreads, Fitness, Movielens-1M, Movielens10M, Gowalla, Arxiv, Arxiv-t, Cora, CS, OGB-Collab, Proteins-0, Proteins-1, Proteins-2, Proteins-3, OGB-DDI, Web-Stanford, RoadNet-PA |
| Ecommerce | Products-tech, Yelp2018, Yelp-textfeat, Products-home, Steam-text, Amazon-text, Amazon-book, Photo, Goodreads, Fitness, Movielens-1M, Movielens10M, Gowalla |
| Academic | Citation-2019, Citation-20Century, Pubmed-link, Citeseer, OGB-PPA, Arxiv, Arxiv-t, Cora, CS, OGB-Collab |
| Others | P2P-Gnutella06, Soc-Epinions1, Email-Enron, Proteins-0, Proteins-1, Proteins-2, Proteins-3, OGB-DDI, Web-Stanford, RoadNet-PA |
| Node | Cora, Arxiv, Pubmed, Home, Tech |

## Experiments

### Model Pre-Training Curves

We present the training logs with respect to epochs below. Each figure contains two curves, corresponding to two repeated pre-training runs.

- pretrain_link1
- pretrain_link2

### Overall Performance Comparison

- Comparison to few-shot end-to-end models and pre-training-and-fine-tuning methods.

![](imgs/overall_performance1.png)

- Comparison to zero-shot graph foundation models.

### Scaling Law of AnyGraph

We explore the scaling law of AnyGraph by evaluating 1) model performance vs. the number of model parameters, and 2) model performance vs. the number of training samples.
The evaluation results below cover:
- all datasets across domains (a)
- academic datasets (b)
- ecommerce datasets (c)
- other datasets (d)

In each subfigure, we show:
- zero-shot performance on unseen datasets w.r.t. the amount of model parameters (left)
- full-shot performance on training datasets w.r.t. the amount of model parameters (middle)
- zero-shot performance w.r.t. the amount of training data (right)

![](imgs/scaling_law.png)

The outcome highlights the following key observations (see Sec. 4.3 for details):
- The generalizability of AnyGraph follows the scaling law.
- AnyGraph exhibits emergent abilities.
- Insufficient training data may introduce bias.

### Ablation Study

The ablation study investigates the impact of the following modules:
- The overall MoE architecture
- Frequency regularization in the expert routing mechanism
- Graph augmentation in the learning process
- The utilization of (heterogeneous) node features from different datasets

### Expert Routing Mechanism

We visualize the competence scores between datasets and experts, given by the routing algorithm of AnyGraph. The scores below reveal the underlying relatedness between different datasets, demonstrating the intuitive effectiveness of the routing mechanism. (See Sec. 4.5 for details.)

### Fast Adaptation of AnyGraph

We study the fast adaptation ability of AnyGraph from two aspects:
- When fine-tuned on unseen datasets, AnyGraph achieves better performance with fewer training steps. (Fig. 6 below)
- The training time of AnyGraph is comparable to that of other methods.
(Table 3 below)

## Citation

If you find our work useful, please consider citing our paper:

```
@article{xia2024anygraph,
  title={AnyGraph: Graph Foundation Model in the Wild},
  author={Xia, Lianghao and Huang, Chao},
  journal={arXiv preprint arXiv:2408.10700},
  year={2024}
}
```


================================================
FILE: Utils/TimeLogger.py
================================================
import datetime

logmsg = ''
timemark = dict()
saveDefault = False

def log(msg, save=None, oneline=False):
    global logmsg
    global saveDefault
    time = datetime.datetime.now()
    tem = '%s: %s' % (time, msg)
    if save != None:
        if save:
            logmsg += tem + '\n'
    elif saveDefault:
        logmsg += tem + '\n'
    if oneline:
        print(tem, end='\r')
    else:
        print(tem)

def marktime(marker):
    global timemark
    timemark[marker] = datetime.datetime.now()

if __name__ == '__main__':
    log('')


================================================
FILE: data_handler.py
================================================
import pickle
import numpy as np
from scipy.sparse import csr_matrix, coo_matrix, dok_matrix
from params import args
import scipy.sparse as sp
from Utils.TimeLogger import log
import torch as t
import torch.utils.data as data
import torch_geometric.transforms as T
from model import Feat_Projector, Adj_Projector, TopoEncoder
import os

class MultiDataHandler:
    def __init__(self, trn_datasets, tst_datasets_group):
        all_datasets = trn_datasets
        all_tst_datasets = []
        for tst_datasets in tst_datasets_group:
            all_datasets = all_datasets + tst_datasets
            all_tst_datasets += tst_datasets
        all_datasets = list(set(all_datasets))
        all_datasets.sort()
        self.trn_handlers = []
        self.tst_handlers_group = [list() for i in range(len(tst_datasets_group))]
        for data_name in all_datasets:
            trn_flag = data_name in trn_datasets
            tst_flag = data_name in all_tst_datasets
            handler = DataHandler(data_name)
            if trn_flag:
                self.trn_handlers.append(handler)
            if tst_flag:
                for i in range(len(tst_datasets_group)):
                    if data_name in tst_datasets_group[i]:
                        self.tst_handlers_group[i].append(handler)
        self.make_joint_trn_loader()

    def make_joint_trn_loader(self):
        loader_datasets = []
        for trn_handler in self.trn_handlers:
            tem_dataset = trn_handler.trn_loader.dataset
            loader_datasets.append(tem_dataset)
        joint_dataset = JointTrnData(loader_datasets)
        self.joint_trn_loader = data.DataLoader(joint_dataset, batch_size=1, shuffle=True, num_workers=0)

    def remake_initial_projections(self):
        for i in range(len(self.trn_handlers)):
            trn_handler = self.trn_handlers[i]
            trn_handler.make_projectors()

class DataHandler:
    def __init__(self, data_name):
        self.data_name = data_name
        self.get_data_files()
        log(f'Loading dataset {data_name}')
        self.topo_encoder = TopoEncoder()
        self.load_data()

    def get_data_files(self):
        predir = f'/home/user_name/data/zero-shot datasets/{self.data_name}/'
        if os.path.exists(predir + 'feats.pkl'):
            self.feat_file = predir + 'feats.pkl'
        else:
            self.feat_file = None
        self.trnfile = predir + 'trn_mat.pkl'
        self.tstfile = predir + 'tst_mat.pkl'
        self.fewshotfile = predir + 'partial_mat_{shot}.pkl'.format(shot=args.ratio_fewshot)
        self.valfile = predir + 'val_mat.pkl'

    def load_one_file(self, filename):
        with open(filename, 'rb') as fs:
            ret = (pickle.load(fs) != 0).astype(np.float32)
        if type(ret) != coo_matrix:
            ret = sp.coo_matrix(ret)
        return ret

    def load_feats(self, filename):
        try:
            with open(filename, 'rb') as fs:
                feats = pickle.load(fs)
        except Exception as e:
            print(filename + str(e))
            exit()
        return feats

    def normalize_adj(self, mat, log=False):
        degree = np.array(mat.sum(axis=-1))
        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
        if mat.shape[0] == mat.shape[1]:
            return mat.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo()
        else:
            tem = d_inv_sqrt_mat.dot(mat)
            col_degree = np.array(mat.sum(axis=0))
            d_inv_sqrt = np.reshape(np.power(col_degree, -0.5), [-1])
            d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
            d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
            return tem.dot(d_inv_sqrt_mat).tocoo()

    def unique_numpy(self, row, col):
        hash_vals = row * args.node_num + col
        hash_vals = np.unique(hash_vals).astype(np.int64)
        col = hash_vals % args.node_num
        row = (hash_vals - col).astype(np.int64) // args.node_num
        return row, col

    def make_torch_adj(self, mat, unidirectional_for_asym=False):
        if mat.shape[0] == mat.shape[1]:
            _row = mat.row
            _col = mat.col
            row = np.concatenate([_row, _col]).astype(np.int64)
            col = np.concatenate([_col, _row]).astype(np.int64)
            row, col = self.unique_numpy(row, col)
            data = np.ones_like(row)
            mat = coo_matrix((data, (row, col)), mat.shape)
            if args.selfloop == 1:
                mat = (mat + sp.eye(mat.shape[0])) * 1.0
            normed_asym_mat = self.normalize_adj(mat)
            row = t.from_numpy(normed_asym_mat.row).long()
            col = t.from_numpy(normed_asym_mat.col).long()
            idxs = t.stack([row, col], dim=0)
            vals = t.from_numpy(normed_asym_mat.data).float()
            shape = t.Size(normed_asym_mat.shape)
            asym_adj = t.sparse.FloatTensor(idxs, vals, shape)
            return asym_adj
        elif unidirectional_for_asym:
            mat = (mat != 0) * 1.0
            mat = self.normalize_adj(mat, log=True)
            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
            vals = t.from_numpy(mat.data.astype(np.float32))
            shape = t.Size(mat.shape)
            return t.sparse.FloatTensor(idxs, vals, shape)
        else:
            # make ui adj
            a = sp.csr_matrix((args.user_num, args.user_num))
            b = sp.csr_matrix((args.item_num, args.item_num))
            mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
            mat = (mat != 0) * 1.0
            if args.selfloop == 1:
                mat = (mat + sp.eye(mat.shape[0])) * 1.0
            mat = self.normalize_adj(mat)
            # make cuda tensor
            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
            vals = t.from_numpy(mat.data.astype(np.float32))
            shape = t.Size(mat.shape)
            return t.sparse.FloatTensor(idxs, vals, shape)

    def load_data(self):
        tst_mat = self.load_one_file(self.tstfile)
        val_mat = self.load_one_file(self.valfile)
        trn_mat = self.load_one_file(self.trnfile)
        fewshot_mat = self.load_one_file(self.fewshotfile)
        if self.feat_file is not None:
            self.feats = t.from_numpy(self.load_feats(self.feat_file)).float()
            args.featdim = self.feats.shape[1]
        else:
            self.feats = None
            args.featdim = args.latdim
        if trn_mat.shape[0] != trn_mat.shape[1]:
            args.user_num, args.item_num = trn_mat.shape
            args.node_num = args.user_num + args.item_num
            print('Dataset: {data_name}, User num: {user_num}, Item num: {item_num}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, user_num=args.user_num, item_num=args.item_num, node_num=args.node_num, edge_num=trn_mat.nnz))
        else:
            args.node_num = trn_mat.shape[0]
            print('Dataset: {data_name}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, node_num=args.node_num, edge_num=trn_mat.nnz + val_mat.nnz + tst_mat.nnz))
        if args.tst_mode == 'tst':
            tst_data = TstData(tst_mat, trn_mat)
            self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
            self.tst_input_adj = self.make_torch_adj(trn_mat)
        elif args.tst_mode == 'val':
            tst_data = TstData(val_mat, trn_mat)
            self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
            self.tst_input_adj = self.make_torch_adj(fewshot_mat)
        else:
            raise Exception('Specify proper test mode')
        if args.trn_mode == 'fewshot':
            self.trn_mat = fewshot_mat
            trn_data = TrnData(self.trn_mat)
            self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0)
            if args.tst_mode == 'val':
                self.trn_input_adj = self.tst_input_adj
            else:
                self.trn_input_adj = self.make_torch_adj(fewshot_mat)
        elif args.trn_mode == 'train-all':
            self.trn_mat = trn_mat
            trn_data = TrnData(self.trn_mat)
            self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0)
            if args.tst_mode == 'tst':
                self.trn_input_adj = self.tst_input_adj
            else:
                self.trn_input_adj = self.make_torch_adj(trn_mat)
        else:
            raise Exception('Specify proper train mode')
        if self.trn_mat.shape[0] == self.trn_mat.shape[1]:
            self.asym_adj = self.trn_input_adj
        else:
            self.asym_adj = self.make_torch_adj(self.trn_mat, unidirectional_for_asym=True)
        self.make_projectors()
        self.reproj_steps = max(len(self.trn_loader.dataset) // (10 * args.batch), args.proj_trn_steps)
        self.ratio_500_all = 500 / len(self.trn_loader)

    def make_projectors(self):
        with t.no_grad():
            projectors = []
            if args.proj_method == 'adj_svd' or args.proj_method == 'both':
                tem = self.asym_adj.to(args.devices[0])
                projectors = [Adj_Projector(tem)]
            if self.feats is not None and args.proj_method != 'adj_svd':
                tem = self.feats.to(args.devices[0])
                projectors.append(Feat_Projector(tem))
            assert args.tst_mode == 'tst' and args.trn_mode == 'train-all' or args.tst_mode == 'val' and args.trn_mode == 'fewshot'
            feats = projectors[0]()
            if len(projectors) == 2:
                feats2 = projectors[1]()
                feats = feats + feats2
            try:
                self.projectors = self.topo_encoder(self.trn_input_adj.to(args.devices[0]), feats.to(args.devices[0])).detach().cpu()
            except Exception:
                print(f'{self.data_name} memory overflow')
                mean, std = feats.mean(dim=-1, keepdim=True), feats.std(dim=-1, keepdim=True)
                tem_adj = self.trn_input_adj.to(args.devices[0])
                mem_cache = 256
                projectors_list = []
                for i in range(feats.shape[1] // mem_cache):
                    st, ed = i * mem_cache, (i + 1) * mem_cache
                    tem_feats = (feats[:, st:ed] - mean) / (std + 1e-8)
                    tem_feats = self.topo_encoder(tem_adj, tem_feats.to(args.devices[0]), normed=True).detach().cpu()
                    projectors_list.append(tem_feats)
                self.projectors = t.concat(projectors_list, dim=-1)
        t.cuda.empty_cache()

class TstData(data.Dataset):
    def __init__(self, coomat, trn_mat):
        self.csrmat = (trn_mat.tocsr() != 0) * 1.0
        tstLocs = [None] * coomat.shape[0]
        tst_nodes = set()
        for i in range(len(coomat.data)):
            row = coomat.row[i]
            col = coomat.col[i]
            if tstLocs[row] is None:
                tstLocs[row] = list()
            tstLocs[row].append(col)
            tst_nodes.add(row)
        tst_nodes = np.array(list(tst_nodes))
        self.tst_nodes = tst_nodes
        self.tstLocs = tstLocs

    def __len__(self):
        return len(self.tst_nodes)

    def __getitem__(self, idx):
        return self.tst_nodes[idx]

class TrnData(data.Dataset):
    def __init__(self, coomat):
        self.ancs, self.poss = coomat.row, coomat.col
        self.negs = np.zeros(len(self.ancs)).astype(np.int32)
        self.cand_num = coomat.shape[1]
        self.neg_shift = 0 if coomat.shape[0] == coomat.shape[1] else coomat.shape[0]
        self.poss = coomat.col + self.neg_shift
        self.neg_sampling()

    def neg_sampling(self):
        self.negs = np.random.randint(self.cand_num + self.neg_shift, size=self.poss.shape[0])

    def __len__(self):
        return len(self.ancs)

    def __getitem__(self, idx):
        return self.ancs[idx], self.poss[idx], self.negs[idx]

class JointTrnData(data.Dataset):
    def __init__(self, dataset_list):
        self.batch_dataset_ids = []
        self.batch_st_ed_list = []
        self.dataset_list = dataset_list
        for dataset_id, dataset in enumerate(dataset_list):
            samp_num = len(dataset) // args.batch + (1 if len(dataset) % args.batch != 0 else 0)
            for j in range(samp_num):
                self.batch_dataset_ids.append(dataset_id)
                st = j * args.batch
                ed = min((j + 1) * args.batch, len(dataset))
                self.batch_st_ed_list.append((st, ed))

    def neg_sampling(self):
        for dataset in self.dataset_list:
            dataset.neg_sampling()

    def __len__(self):
        return len(self.batch_dataset_ids)

    def __getitem__(self, idx):
        st, ed = self.batch_st_ed_list[idx]
        dataset_id = self.batch_dataset_ids[idx]
        return *self.dataset_list[dataset_id][st: ed], dataset_id


================================================
FILE: main.py
================================================
import torch as t
from torch import nn
import Utils.TimeLogger as logger
from Utils.TimeLogger import log
from params import args
from model import Expert, Feat_Projector, Adj_Projector, AnyGraph
from data_handler import MultiDataHandler, DataHandler
import numpy as np
import pickle
import os
import setproctitle
import time

class Exp:
    def __init__(self, multi_handler):
        self.multi_handler = multi_handler
        print(list(map(lambda x:
            x.data_name, multi_handler.trn_handlers)))
        for group_id, tst_handlers in enumerate(multi_handler.tst_handlers_group):
            print(f'Test group {group_id}', list(map(lambda x: x.data_name, tst_handlers)))
        self.metrics = dict()
        trn_mets = ['Loss', 'preLoss']
        tst_mets = ['Recall', 'NDCG', 'Loss', 'preLoss']
        mets = trn_mets + tst_mets
        for met in mets:
            if met in trn_mets:
                self.metrics['Train' + met] = list()
            if met in tst_mets:
                for i in range(len(self.multi_handler.tst_handlers_group)):
                    self.metrics['Test' + str(i) + met] = list()

    def make_print(self, name, ep, reses, save, data_name=None):
        if data_name is None:
            ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)
        else:
            ret = 'Epoch %d/%d, %s %s: ' % (ep, args.epoch, data_name, name)
        for metric in reses:
            val = reses[metric]
            ret += '%s = %.4f, ' % (metric, val)
            tem = name + metric if data_name is None else name + data_name + metric
            if save and tem in self.metrics:
                self.metrics[tem].append(val)
        ret = ret[:-2] + ' '
        return ret

    def run(self):
        self.prepare_model()
        log('Model Prepared')
        stloc = 0
        if args.load_model != None:
            self.load_model()
            stloc = len(self.metrics['TrainLoss']) * args.tst_epoch - (args.tst_epoch - 1)
        best_ndcg, best_ep = 0, -1
        for ep in range(stloc, args.epoch):
            tst_flag = (ep % args.tst_epoch == 0)
            start_time = time.time()
            self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=True)
            reses = self.train_epoch()
            log(self.make_print('Train', ep, reses, tst_flag))
            self.multi_handler.remake_initial_projections()
            end_time = time.time()
            print(f'NOTICE: {end_time-start_time}')
            if tst_flag:
                for handler_group_id in range(len(self.multi_handler.tst_handlers_group)):
                    tst_handlers = self.multi_handler.tst_handlers_group[handler_group_id]
                    self.model.assign_experts(tst_handlers, reca=False, log_assignment=True)
                    recall, ndcg, tstnum = 0, 0, 0
                    for i, handler in enumerate(tst_handlers):
                        reses = self.test_epoch(handler, i)
                        # log(self.make_print(f'{handler.data_name}', ep, reses, False))
                        recall += reses['Recall'] * reses['tstNum']
                        ndcg += reses['NDCG'] * reses['tstNum']
                        tstnum += reses['tstNum']
                    reses = {'Recall': recall / tstnum, 'NDCG': ndcg / tstnum}
                    log(self.make_print('Test' + str(handler_group_id), ep, reses, tst_flag))
                    if reses['NDCG'] > best_ndcg:
                        best_ndcg = reses['NDCG']
                        best_ep = ep
                self.save_history()
                print()
        for test_group_id in range(len(self.multi_handler.tst_handlers_group)):
            repeat_times = 5
            overall_recall, overall_ndcg = np.zeros(repeat_times), np.zeros(repeat_times)
            overall_tstnum = 0
            tst_handlers = self.multi_handler.tst_handlers_group[test_group_id]
            for i, handler in enumerate(tst_handlers):
                for topk in [args.topk]:
                    args.topk = topk
                    mets = dict()
                    for _ in range(repeat_times):
                        handler.make_projectors()
                        self.model.assign_experts([handler], reca=False, log_assignment=False)
                        reses = self.test_epoch(handler, 0)
                        for met in reses:
                            if met not in mets:
                                mets[met] = []
                            mets[met].append(reses[met])
                        tstnum = reses['tstNum']
                    tot_reses = dict()
                    for met in reses:
                        tem_arr = np.array(mets[met])
                        tot_reses[met + '_std'] = tem_arr.std()
                        tot_reses[met + '_mean'] = tem_arr.mean()
                    if topk == args.topk:
                        overall_recall += np.array(mets['Recall']) * tstnum
                        overall_ndcg += np.array(mets['NDCG']) * tstnum
                        overall_tstnum += tstnum
                    log(self.make_print(f'Test Top-{topk}', args.epoch, tot_reses, False, handler.data_name))
            overall_recall /= overall_tstnum
            overall_ndcg /= overall_tstnum
            overall_res = dict()
            overall_res['Recall_mean'] = overall_recall.mean()
            overall_res['Recall_std'] = overall_recall.std()
            overall_res['NDCG_mean'] = overall_ndcg.mean()
            overall_res['NDCG_std'] = overall_ndcg.std()
            log(self.make_print('Overall Test', args.epoch, overall_res, False))
        self.save_history()

    def print_model_size(self):
        total_params = 0
        trainable_params = 0
        non_trainable_params = 0
        for param in self.model.parameters():
            tem = np.prod(param.size())
            total_params += tem
            if param.requires_grad:
                trainable_params += tem
            else:
                non_trainable_params += tem
        print(f'Total params: {total_params/1e6}')
        print(f'Trainable params: {trainable_params/1e6}')
        print(f'Non-trainable params: {non_trainable_params/1e6}')

    def prepare_model(self):
        self.model = AnyGraph()
        t.cuda.empty_cache()
        self.print_model_size()

    def train_epoch(self):
        self.model.train()
        trn_loader = self.multi_handler.joint_trn_loader
        trn_loader.dataset.neg_sampling()
        ep_loss, ep_preloss, ep_regloss = 0, 0, 0
        steps = len(trn_loader)
        tot_samp_num = 0
        counter = [0] * len(self.multi_handler.trn_handlers)
        reassign_steps = sum(list(map(lambda x: x.reproj_steps, self.multi_handler.trn_handlers)))
        for i, batch_data in enumerate(trn_loader):
            if args.epoch_max_step > 0 and i >= args.epoch_max_step:
                break
            ancs, poss, negs, dataset_id = batch_data
            ancs = ancs[0].long()
            poss = poss[0].long()
            negs = negs[0].long()
            dataset_id = dataset_id[0].long()
            tem_bar = self.multi_handler.trn_handlers[dataset_id].ratio_500_all
            if tem_bar < 1.0 and np.random.uniform() > tem_bar:
                steps -= 1
                continue
            expert = self.model.summon(dataset_id)
            opt = self.model.summon_opt(dataset_id)
            feats = self.multi_handler.trn_handlers[dataset_id].projectors
            loss, loss_dict = expert.cal_loss((ancs, poss, negs), feats)
            opt.zero_grad()
            loss.backward()
            opt.step()
            sample_num = ancs.shape[0]
            tot_samp_num += sample_num
            ep_loss += loss.item() * sample_num
            ep_preloss += loss_dict['preloss'].item() * sample_num
            ep_regloss += loss_dict['regloss'].item()
            log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True)
            counter[dataset_id] += 1
            if (counter[dataset_id] + 1) % self.multi_handler.trn_handlers[dataset_id].reproj_steps == 0:
                self.multi_handler.trn_handlers[dataset_id].make_projectors()
            if (i + 1) % reassign_steps == 0:
                self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=False)
        ret = dict()
        ret['Loss'] = ep_loss / tot_samp_num
        ret['preLoss'] = ep_preloss / tot_samp_num
        ret['regLoss'] = ep_regloss / steps
        t.cuda.empty_cache()
        return ret

    def make_trn_masks(self, numpy_usrs, csr_mat):
        trn_masks = csr_mat[numpy_usrs].tocoo()
        cand_size = trn_masks.shape[1]
        trn_masks = t.from_numpy(np.stack([trn_masks.row, trn_masks.col], axis=0)).long()
        return trn_masks, cand_size

    def test_epoch(self, handler, dataset_id):
        with t.no_grad():
            tst_loader = handler.tst_loader
            self.model.eval()
            expert = self.model.summon(dataset_id)
            ep_recall, ep_ndcg = 0, 0
            ep_tstnum = len(tst_loader.dataset)
            steps = max(ep_tstnum // args.tst_batch, 1)
            for i, batch_data in enumerate(tst_loader):
                if args.tst_steps != -1 and i > args.tst_steps:
                    break
                usrs = batch_data.long()
                trn_masks, cand_size = self.make_trn_masks(batch_data.numpy(), tst_loader.dataset.csrmat)
                feats = handler.projectors
                all_preds = expert.pred_for_test((usrs, trn_masks), cand_size, feats, rerun_embed=False if i != 0 else True)
                _, top_locs = t.topk(all_preds, args.topk)
                top_locs = top_locs.cpu().numpy()
                recall, ndcg = self.calc_recall_ndcg(top_locs, tst_loader.dataset.tstLocs, usrs)
                ep_recall += recall
                ep_ndcg += ndcg
                log('Steps %d/%d: recall = %.2f, ndcg = %.2f ' % (i, steps, recall, ndcg), save=False, oneline=True)
            ret = dict()
            if args.tst_steps != -1:
                ep_tstnum = args.tst_steps * args.tst_batch
            ret['Recall'] = ep_recall / ep_tstnum
            ret['NDCG'] = ep_ndcg / ep_tstnum
            ret['tstNum'] = ep_tstnum
        t.cuda.empty_cache()
        return ret

    def calc_recall_ndcg(self, topLocs, tstLocs, batIds):
        assert topLocs.shape[0] == len(batIds)
        allRecall = allNdcg = 0
        for i in range(len(batIds)):
            temTopLocs = list(topLocs[i])
            temTstLocs = tstLocs[batIds[i]]
            tstNum = len(temTstLocs)
            maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))])
            recall = dcg = 0
            for val in temTstLocs:
                if val in temTopLocs:
                    recall += 1
                    dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2))
            recall = recall / tstNum
            ndcg = dcg / maxDcg
            allRecall += recall
            allNdcg += ndcg
        return allRecall, allNdcg

    def save_history(self):
        if args.epoch == 0:
            return
        with open('./History/' + args.save_path + '.his', 'wb') as fs:
            pickle.dump(self.metrics, fs)
        content = {
            'model': self.model,
        }
        t.save(content, './Models/' + args.save_path + '.mod')
        log('Model Saved: %s' % args.save_path)

    def load_model(self):
        ckp = t.load('./Models/' + args.load_model + '.mod')
        self.model = ckp['model']
        self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)
        with open('./History/' + args.load_model + '.his', 'rb') as fs:
            self.metrics = pickle.load(fs)
        log('Model Loaded')

if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if len(args.gpu.split(',')) == 2:
        args.devices = ['cuda:0', 'cuda:1']
    elif len(args.gpu.split(',')) > 2:
        raise Exception('Devices should be less than 2')
    else:
        args.devices = ['cuda:0', 'cuda:0']
    logger.saveDefault = True
    setproctitle.setproctitle('AnyGraph')
    log('Start')

    datasets = dict()
    datasets['all'] = [
        'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat',
        'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech',
        'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic',
        'collab', 'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3',
        'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1'
    ]
    datasets['ecommerce'] = [
        'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat',
        'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech'
    ]
    datasets['academic'] = [
        'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab'
    ]
    datasets['others'] = [
        'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3',
        'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1'
    ]
    datasets['link1'] = [
        'products_tech', 'yelp2018', 'yelp_textfeat',
        'products_home', 'steam_textfeat', 'amazon_textfeat', 'amazon-book',
        'citation-2019', 'citation-classic', 'pubmed', 'citeseer', 'ppa',
        'p2p-Gnutella06', 'soc-Epinions1', 'email-Enron',
    ]
    datasets['link2'] = [
        'Photo', 'Goodreads', 'Fitness', 'ml1m', 'ml10m', 'gowalla',
        'arxiv', 'arxiv-ta', 'cora', 'CS', 'collab',
        'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'ddi',
        'web-Stanford', 'roadNet-PA',
    ]

    if args.dataset_setting in datasets.keys():
        trn_datasets = tst_datasets = datasets[args.dataset_setting]
    elif args.dataset_setting in datasets['all']:
        trn_datasets = tst_datasets = [args.dataset_setting]
    elif '+' in args.dataset_setting:
        idx = args.dataset_setting.index('+')
        trn_datasets = datasets[args.dataset_setting[:idx]]
        tst_datasets = datasets[args.dataset_setting[idx+1:]]
    elif '_in_' in args.dataset_setting:
        idx = args.dataset_setting.index('_in_')
        tst_datasets_1 = datasets[args.dataset_setting[:idx]]
        tst_datasets_2 = datasets[args.dataset_setting[idx+len('_in_'):]]
        tst_datasets = []
        for data in tst_datasets_1:
            if data in tst_datasets_2:
                tst_datasets.append(data)
        trn_datasets = tst_datasets
    if '+' not in args.dataset_setting:
        # No zero-shot prediction test
        handler = MultiDataHandler(trn_datasets, [tst_datasets])
    else:
        handler = MultiDataHandler(trn_datasets, [trn_datasets, tst_datasets])
    log('Load Data')
    exp = Exp(handler)
    exp.run()
    print(args.load_model, args.dataset_setting)


================================================
FILE: model.py
================================================
import torch as t
from torch import nn
import torch.nn.functional as F
from params import args
import numpy as np
from Utils.TimeLogger import log
from torch.nn import MultiheadAttention
from time import time

init = nn.init.xavier_uniform_
uniformInit = nn.init.uniform_

class FeedForwardLayer(nn.Module):
    def __init__(self, in_feat, out_feat, bias=True, act=None):
        super(FeedForwardLayer, self).__init__()
        self.linear = nn.Linear(in_feat, out_feat,
bias=bias)#, dtype=t.bfloat16) if act == 'identity' or act is None: self.act = None elif act == 'leaky': self.act = nn.LeakyReLU(negative_slope=args.leaky) elif act == 'relu': self.act = nn.ReLU() elif act == 'relu6': self.act = nn.ReLU6() else: raise Exception('Error') def forward(self, embeds): if self.act is None: return self.linear(embeds) return self.act(self.linear(embeds)) class TopoEncoder(nn.Module): def __init__(self): super(TopoEncoder, self).__init__() self.layer_norm = nn.LayerNorm(args.latdim, elementwise_affine=False) def forward(self, adj, embeds, normed=False): with t.no_grad(): if not normed: embeds = self.layer_norm(embeds) # embeds_list = [] final_embeds = 0 if args.gnn_layer == 0: final_embeds = embeds # embeds_list.append(embeds) for _ in range(args.gnn_layer): embeds = t.spmm(adj, embeds) final_embeds += embeds # embeds_list.append(embeds) embeds = final_embeds#sum(embeds_list) return embeds class MLP(nn.Module): def __init__(self): super(MLP, self).__init__() self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(args.fc_layer)]) self.layer_norms = nn.Sequential(*[nn.LayerNorm(args.latdim, elementwise_affine=True) for _ in range(args.fc_layer)]) self.dropout = nn.Dropout(p=args.drop_rate) def forward(self, embeds): for i in range(args.fc_layer): embeds = self.layer_norms[i](self.dropout(self.dense_layers[i](embeds)) + embeds) return embeds class GTLayer(nn.Module): def __init__(self): super(GTLayer, self).__init__() self.multi_head_attention = MultiheadAttention(args.latdim, args.head, dropout=0.1, bias=False)#, dtype=t.bfloat16) self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(2)])# bias=False self.layer_norm1 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16) self.layer_norm2 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16) self.fc_dropout = 
nn.Dropout(p=args.drop_rate) def _pick_anchors(self, embeds): perm = t.randperm(embeds.shape[0]) anchors = perm[:args.anchor] return embeds[anchors] def forward(self, embeds): anchor_embeds = self._pick_anchors(embeds) _anchor_embeds, _ = self.multi_head_attention(anchor_embeds, embeds, embeds) anchor_embeds = _anchor_embeds + anchor_embeds _embeds, _ = self.multi_head_attention(embeds, anchor_embeds, anchor_embeds, need_weights=False) embeds = self.layer_norm1(_embeds + embeds) _embeds = self.fc_dropout(self.dense_layers(embeds)) embeds = (self.layer_norm2(_embeds + embeds)) return embeds class GraphTransformer(nn.Module): def __init__(self): super(GraphTransformer, self).__init__() self.gt_layers = nn.Sequential(*[GTLayer() for i in range(args.gt_layer)]) def forward(self, embeds): for i, layer in enumerate(self.gt_layers): embeds = layer(embeds) / args.scale_layer return embeds class Feat_Projector(nn.Module): def __init__(self, feats): super(Feat_Projector, self).__init__() if args.proj_method == 'uniform': self.proj_embeds = self.uniform_proj(feats) elif args.proj_method == 'svd' or args.proj_method == 'both': self.proj_embeds = self.svd_proj(feats) elif args.proj_method == 'random': self.proj_embeds = self.random_proj(feats) elif args.proj_method == 'original': self.proj_embeds = feats self.proj_embeds = t.flip(self.proj_embeds, dims=[-1]) self.proj_embeds = self.proj_embeds.detach() def svd_proj(self, feats): if args.latdim > feats.shape[0] or args.latdim > feats.shape[1]: dim = min(feats.shape[0], feats.shape[1]) decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=dim, niter=args.niter) decom_feats = t.concat([decom_feats, t.zeros([decom_feats.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1) s = t.concat([s, t.zeros(args.latdim - dim).to(args.devices[0])]) else: decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=args.latdim, niter=args.niter) decom_feats = decom_feats @ t.diag(t.sqrt(s)) return decom_feats.cpu() def uniform_proj(self, feats): 
projection = init(t.empty(args.featdim, args.latdim)) return feats @ projection def random_proj(self, feats): projection = init(t.empty(feats.shape[0], args.latdim)) return projection def forward(self): return self.proj_embeds class Adj_Projector(nn.Module): def __init__(self, adj): super(Adj_Projector, self).__init__() if args.proj_method == 'adj_svd' or args.proj_method == 'both': self.proj_embeds = self.svd_proj(adj) self.proj_embeds = self.proj_embeds.detach() def svd_proj(self, adj): q = args.latdim if args.latdim > adj.shape[0] or args.latdim > adj.shape[1]: dim = min(adj.shape[0], adj.shape[1]) svd_u, s, svd_v = t.svd_lowrank(adj, q=dim, niter=args.niter) svd_u = t.concat([svd_u, t.zeros([svd_u.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1) svd_v = t.concat([svd_v, t.zeros([svd_v.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1) s = t.concat([s, t.zeros(args.latdim-dim).to(args.devices[0])]) else: svd_u, s, svd_v = t.svd_lowrank(adj, q=q, niter=args.niter) svd_u = svd_u @ t.diag(t.sqrt(s)) svd_v = svd_v @ t.diag(t.sqrt(s)) if adj.shape[0] != adj.shape[1]: projection = t.concat([svd_u, svd_v], dim=0) else: projection = svd_u + svd_v return projection.cpu() def forward(self): return self.proj_embeds class Expert(nn.Module): def __init__(self): super(Expert, self).__init__() self.topo_encoder = TopoEncoder().to(args.devices[0]) if args.nn == 'mlp': self.trainable_nn = MLP().to(args.devices[1]) else: self.trainable_nn = GraphTransformer().to(args.devices[1]) self.trn_count = 1 def forward(self, projectors, pck_nodes=None): embeds = projectors.to(args.devices[1]) if pck_nodes is not None: embeds = embeds[pck_nodes] embeds = self.trainable_nn(embeds) return embeds def pred_norm(self, pos_preds, neg_preds): pos_preds_num = pos_preds.shape[0] neg_preds_shape = neg_preds.shape preds = t.concat([pos_preds, neg_preds.view(-1)]) preds = preds - preds.max() pos_preds = preds[:pos_preds_num] neg_preds = preds[pos_preds_num:].view(neg_preds_shape) return 
pos_preds, neg_preds def cal_loss(self, batch_data, projectors): ancs, poss, negs = list(map(lambda x: x.to(args.devices[1]), batch_data)) self.trn_count += ancs.shape[0] pck_nodes = t.concat([ancs, poss, negs]) final_embeds = self.forward(projectors, pck_nodes) # anc_embeds, pos_embeds, neg_embeds = final_embeds[ancs], final_embeds[poss], final_embeds[negs] anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds, [ancs.shape[0]] * 3) if final_embeds.isinf().any() or final_embeds.isnan().any(): raise Exception('Final embedding fails') if args.loss == 'ce': pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T) if pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any(): raise Exception('Preds fails') pos_loss = pos_preds neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log() pre_loss = -(pos_loss - neg_loss).mean() elif args.loss == 'bpr': pos_preds = (anc_embeds * pos_embeds).sum(-1) neg_preds = (anc_embeds * neg_embeds).sum(-1) pos_loss, neg_loss = pos_preds, neg_preds pre_loss = -((pos_preds - neg_preds).sigmoid() + 1e-10).log().mean() if t.isinf(pre_loss).any() or t.isnan(pre_loss).any(): raise Exception('NaN or Inf') reg_loss = sum(list(map(lambda W: W.norm(2).square() * args.reg, self.parameters()))) loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean()} return pre_loss + reg_loss, loss_dict def pred_for_test(self, batch_data, cand_size, projectors, rerun_embed=True): ancs, trn_mask = list(map(lambda x: x.to(args.devices[1]), batch_data)) if rerun_embed: try: final_embeds = self.forward(projectors) except Exception: final_embeds_list = [] div = args.batch * 3 temlen = projectors.shape[0] // div for i in range(temlen): st, ed = div * i, div * (i + 1) tem_projectors = projectors[st: ed, :] final_embeds_list.append(self.forward(tem_projectors)) if temlen * div < projectors.shape[0]: tem_projectors = 
projectors[temlen*div:, :] final_embeds_list.append(self.forward(tem_projectors)) final_embeds = t.concat(final_embeds_list, dim=0) self.final_embeds = final_embeds final_embeds = self.final_embeds anc_embeds = final_embeds[ancs] cand_embeds = final_embeds[-cand_size:] mask_mat = t.sparse.FloatTensor(trn_mask, t.ones(trn_mask.shape[1]).to(args.devices[1]), t.Size([ancs.shape[0], cand_size])) dense_mat = mask_mat.to_dense() all_preds = anc_embeds @ cand_embeds.T * (1 - dense_mat) - dense_mat * 1e8 return all_preds def attempt(self, topo_embeds, dataset): final_embeds = self.trainable_nn(topo_embeds) rows, cols, negs = list(map(lambda x: t.from_numpy(x).long().to(args.devices[1]), [dataset.ancs, dataset.poss, dataset.negs])) if rows.shape[0] > args.attempt_cache: random_perm = t.randperm(rows.shape[0], device=args.devices[0]) pck_perm = random_perm[:args.attempt_cache] rows = rows[pck_perm] cols = cols[pck_perm] negs = negs[pck_perm] while True: try: row_embeds = final_embeds[rows] col_embeds = final_embeds[cols] neg_embeds = final_embeds[negs] score = ((row_embeds * col_embeds).sum(-1) - (row_embeds * neg_embeds).sum(-1)).sigmoid().mean().item() break except Exception: args.attempt_cache = args.attempt_cache // 2 random_perm = t.randperm(rows.shape[0], device=args.devices[0]) pck_perm = random_perm[:args.attempt_cache] rows = rows[pck_perm] cols = cols[pck_perm] negs = negs[pck_perm] t.cuda.empty_cache() return score class AnyGraph(nn.Module): def __init__(self): super(AnyGraph, self).__init__() self.experts = nn.ModuleList([Expert() for _ in range(args.expert_num)]).cuda() self.opts = list(map(lambda expert: t.optim.Adam(expert.parameters(), lr=args.lr, weight_decay=0), self.experts)) def assign_experts(self, handlers, reca=True, log_assignment=False): if args.expert_num == 1: self.assignment = [0] * len(handlers) return try: expert_scores = np.array(list(map(lambda expert: expert.trn_count, self.experts))) expert_scores = (1.0 - expert_scores / 
np.sum(expert_scores)) * args.reca_range + 1.0 - args.reca_range / 2 except Exception: expert_scores = np.ones(len(self.experts)) with t.no_grad(): assignment = [list() for i in range(len(handlers))] for dataset_id, handler in enumerate(handlers): topo_embeds = handler.projectors.to(args.devices[1]) for expert_id, expert in enumerate(self.experts): expert = expert.to(args.devices[1]) score = expert.attempt(topo_embeds, handler.trn_loader.dataset) if reca: score *= expert_scores[expert_id] assignment[dataset_id].append((expert_id, score)) assignment[dataset_id].sort(key=lambda x: x[1], reverse=True) if log_assignment: print('\n----------\nAssignment') for dataset_id, handler in enumerate(handlers): out = '' for exp_idx in range(min(4, len(self.experts))): out += f'({assignment[dataset_id][exp_idx][0]}, {assignment[dataset_id][exp_idx][1]}) ' print(handler.data_name, out) print('----------\n') self.assignment = list(map(lambda x: x[0][0], assignment)) def summon(self, dataset_id): return self.experts[self.assignment[dataset_id]] def summon_opt(self, dataset_id): return self.opts[self.assignment[dataset_id]] ================================================ FILE: node_classification/Utils/TimeLogger.py ================================================ import datetime logmsg = '' timemark = dict() saveDefault = False def log(msg, save=None, oneline=False): global logmsg global saveDefault time = datetime.datetime.now() tem = '%s: %s' % (time, msg) if save != None: if save: logmsg += tem + '\n' elif saveDefault: logmsg += tem + '\n' if oneline: print(tem, end='\r') else: print(tem) def marktime(marker): global timemark timemark[marker] = datetime.datetime.now() if __name__ == '__main__': log('') ================================================ FILE: node_classification/data_handler.py ================================================ import pickle import numpy as np from scipy.sparse import csr_matrix, coo_matrix, dok_matrix from params import args import scipy.sparse as 
sp from Utils.TimeLogger import log import torch as t import torch.utils.data as data import torch_geometric.transforms as T from model import Feat_Projector, Adj_Projector, TopoEncoder import os class MultiDataHandler: def __init__(self, trn_datasets, tst_datasets_group): all_datasets = trn_datasets all_tst_datasets = [] for tst_datasets in tst_datasets_group: all_datasets = all_datasets + tst_datasets all_tst_datasets += tst_datasets all_datasets = list(set(all_datasets)) all_datasets.sort() self.trn_handlers = [] self.tst_handlers_group = [list() for i in range(len(tst_datasets_group))] for data_name in all_datasets: trn_flag = data_name in trn_datasets tst_flag = data_name in all_tst_datasets handler = DataHandler(data_name) if trn_flag: self.trn_handlers.append(handler) if tst_flag: for i in range(len(tst_datasets_group)): if data_name in tst_datasets_group[i]: self.tst_handlers_group[i].append(handler) self.make_joint_trn_loader() def make_joint_trn_loader(self): loader_datasets = [] for trn_handler in self.trn_handlers: tem_dataset = trn_handler.trn_loader.dataset loader_datasets.append(tem_dataset) joint_dataset = JointTrnData(loader_datasets) self.joint_trn_loader = data.DataLoader(joint_dataset, batch_size=1, shuffle=True, num_workers=0) def remake_initial_projections(self): for i in range(len(self.trn_handlers)): trn_handler = self.trn_handlers[i] trn_handler.make_projectors() class DataHandler: def __init__(self, data_name): self.data_name = data_name self.get_data_files() log(f'Loading dataset {data_name}') self.topo_encoder = TopoEncoder() self.load_data() def get_data_files(self): predir = f'/home/user_name/data/zero-shot datasets/node_data/{self.data_name}/' # predir = f'../handle_node_data/{self.data_name}/' if os.path.exists(predir + 'feats.pkl'): self.feat_file = predir + 'feats.pkl' else: self.feat_file = None self.trnfile = predir + 'trn_mat.pkl' self.tstfile = predir + 'tst_mat.pkl' self.fewshotfile = predir + 
'fewshot_mat_{shot}.pkl'.format(shot=args.shot) self.valfile = predir + 'val_mat.pkl' def load_one_file(self, filename): with open(filename, 'rb') as fs: ret = (pickle.load(fs) != 0).astype(np.float32) if type(ret) != coo_matrix: ret = sp.coo_matrix(ret) return ret def load_feats(self, filename): try: with open(filename, 'rb') as fs: feats = pickle.load(fs) except Exception as e: print(filename + str(e)) exit() return feats def normalize_adj(self, mat, log=False): degree = np.array(mat.sum(axis=-1)) d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1]) d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0 d_inv_sqrt_mat = sp.diags(d_inv_sqrt) if mat.shape[0] == mat.shape[1]: return mat.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo() else: tem = d_inv_sqrt_mat.dot(mat) col_degree = np.array(mat.sum(axis=0)) d_inv_sqrt = np.reshape(np.power(col_degree, -0.5), [-1]) d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0 d_inv_sqrt_mat = sp.diags(d_inv_sqrt) return tem.dot(d_inv_sqrt_mat).tocoo() def unique_numpy(self, row, col): hash_vals = row * args.node_num + col hash_vals = np.unique(hash_vals).astype(np.int64) col = hash_vals % args.node_num row = (hash_vals - col).astype(np.int64) // args.node_num return row, col def make_torch_adj(self, mat, unidirectional_for_asym=False): if mat.shape[0] == mat.shape[1]: _row = mat.row _col = mat.col row = np.concatenate([_row, _col]).astype(np.int64) col = np.concatenate([_col, _row]).astype(np.int64) row, col = self.unique_numpy(row, col) data = np.ones_like(row) mat = coo_matrix((data, (row, col)), mat.shape) if args.selfloop == 1: mat = (mat + sp.eye(mat.shape[0])) * 1.0 normed_asym_mat = self.normalize_adj(mat) row = t.from_numpy(normed_asym_mat.row).long() col = t.from_numpy(normed_asym_mat.col).long() idxs = t.stack([row, col], dim=0) vals = t.from_numpy(normed_asym_mat.data).float() shape = t.Size(normed_asym_mat.shape) asym_adj = t.sparse.FloatTensor(idxs, vals, shape) return asym_adj elif unidirectional_for_asym: mat = (mat != 0) * 1.0 
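The `normalize_adj` routine above implements the standard symmetric normalization A_hat = D^{-1/2} A D^{-1/2}, zeroing the inverse degree of isolated nodes instead of letting it become infinite. A standalone sketch (not the repository's code; `sym_normalize` is a hypothetical helper mirroring the square-matrix branch):

```python
# Sketch of the symmetric normalization in DataHandler.normalize_adj:
# A_hat = D^{-1/2} A D^{-1/2}, with isolated nodes mapped to 0 rather than inf.
import numpy as np
import scipy.sparse as sp

def sym_normalize(mat):
    degree = np.array(mat.sum(axis=1))                 # row degrees
    d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0             # guard isolated nodes
    d_mat = sp.diags(d_inv_sqrt)
    # (A D^{-1/2})^T D^{-1/2} = D^{-1/2} A D^{-1/2} for symmetric A
    return mat.dot(d_mat).transpose().dot(d_mat).tocoo()

# tiny path graph: edges 0-1 and 0-2, degrees [2, 1, 1]
adj = sp.coo_matrix(np.array([[0., 1., 1.], [1., 0., 0.], [1., 0., 0.]]))
normed = sym_normalize(adj).toarray()
```

Each normalized entry is a_ij * d_i^{-1/2} * d_j^{-1/2}, so the (0, 1) entry here becomes 1/sqrt(2).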
mat = self.normalize_adj(mat, log=True) idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64)) vals = t.from_numpy(mat.data.astype(np.float32)) shape = t.Size(mat.shape) return t.sparse.FloatTensor(idxs, vals, shape) else: # make ui adj a = sp.csr_matrix((args.user_num, args.user_num)) b = sp.csr_matrix((args.item_num, args.item_num)) mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])]) mat = (mat != 0) * 1.0 if args.selfloop == 1: mat = (mat + sp.eye(mat.shape[0])) * 1.0 mat = self.normalize_adj(mat) # make cuda tensor idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64)) vals = t.from_numpy(mat.data.astype(np.float32)) shape = t.Size(mat.shape) return t.sparse.FloatTensor(idxs, vals, shape) def load_data(self): tst_mat = self.load_one_file(self.tstfile) trn_mat = self.load_one_file(self.trnfile) # fewshot_mat = self.load_one_file(self.fewshotfile) if self.feat_file is not None: self.feats = t.from_numpy(self.load_feats(self.feat_file)).float() self.feats = self.feats args.featdim = self.feats.shape[1] else: self.feats = None args.featdim = args.latdim if trn_mat.shape[0] != trn_mat.shape[1]: args.user_num, args.item_num = trn_mat.shape args.node_num = args.user_num + args.item_num print('Dataset: {data_name}, User num: {user_num}, Item num: {item_num}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, user_num=args.user_num, item_num=args.item_num, node_num=args.node_num, edge_num=trn_mat.nnz)) else: args.node_num = trn_mat.shape[0] print('Dataset: {data_name}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, node_num=args.node_num, edge_num=trn_mat.nnz+tst_mat.nnz)) if args.tst_mode == 'tst': tst_data = NodeTstData(tst_mat) self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0) # tst_loss_data = TrnData(tst_mat) # self.tst_loss_loader = data.DataLoader(tst_loss_data, batch_size=args.batch, shuffle=False, 
num_workers=0) self.tst_input_adj = self.make_torch_adj(trn_mat) elif args.tst_mode == 'val': raise Exception('Validation not made for node classification') else: raise Exception('Specify proper test mode') self.trn_mat = trn_mat trn_data = TrnData(self.trn_mat) self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0) if args.tst_mode == 'tst': self.trn_input_adj = self.tst_input_adj else: self.trn_input_adj = self.make_torch_adj(trn_mat) if self.trn_mat.shape[0] == self.trn_mat.shape[1]: self.asym_adj = self.trn_input_adj else: self.asym_adj = self.make_torch_adj(self.trn_mat, unidirectional_for_asym=True) self.make_projectors() self.reproj_steps = max(len(self.trn_loader.dataset) // (10 * args.batch), args.proj_trn_steps) self.ratio_500_all = 250 / len(self.trn_loader) def make_projectors(self): with t.no_grad(): projectors = [] if args.proj_method == 'adj_svd' or args.proj_method == 'both': tem = self.asym_adj.to(args.devices[0]) projectors = [Adj_Projector(tem)] if self.feats is not None and args.proj_method != 'adj_svd': tem = self.feats.to(args.devices[0]) projectors.append(Feat_Projector(tem)) assert args.tst_mode == 'tst' and args.trn_mode == 'train-all' or args.tst_mode == 'val' and args.trn_mode == 'fewshot' feats = projectors[0]() if len(projectors) == 2: feats2 = projectors[1]() feats = feats + feats2 try: self.projectors = self.topo_encoder(self.trn_input_adj.to(args.devices[0]), feats.to(args.devices[0])).detach().cpu() except Exception: print(f'{self.data_name} memory overflow') mean, std = feats.mean(dim=-1, keepdim=True), feats.std(dim=-1, keepdim=True) tem_adj = self.trn_input_adj.to(args.devices[0]) mem_cache = 256 projectors_list = [] for i in range(feats.shape[1] // mem_cache): st, ed = i * mem_cache, (i + 1) * mem_cache tem_feats = (feats[:, st:ed] - mean) / (std + 1e-8) tem_feats = self.topo_encoder(tem_adj, tem_feats.to(args.devices[0]), normed=True).detach().cpu() projectors_list.append(tem_feats) 
self.projectors = t.concat(projectors_list, dim=-1) t.cuda.empty_cache() class TrnData(data.Dataset): def __init__(self, coomat): self.ancs, self.poss = coomat.row, coomat.col # self.dokmat = set(list(map(lambda idx: (self.ancs[idx], self.poss[idx]), range(len(self.ancs))))) # self.dokmat = coomat.todok() self.negs = np.zeros(len(self.ancs)).astype(np.int32) self.cand_num = coomat.shape[1] self.neg_shift = 0 if coomat.shape[0] == coomat.shape[1] else coomat.shape[0] self.poss = coomat.col + self.neg_shift self.neg_sampling() def neg_sampling(self): self.negs = np.random.randint(self.cand_num + self.neg_shift, size=self.poss.shape[0]) # self.negs = np.zeros_like(self.ancs) # for i in range(len(self.ancs)): # u = self.ancs[i] # while True: # i_neg = np.random.randint(self.cand_num) # if (u, i_neg) not in self.dokmat: # break # self.negs[i] = i_neg # self.negs += self.neg_shift def __len__(self): return len(self.ancs) def __getitem__(self, idx): return self.ancs[idx], self.poss[idx] , self.negs[idx] class NodeTstData(data.Dataset): def __init__(self, tst_mat): self.class_num = tst_mat.shape[1] self.nodes, self.labels = tst_mat.row, tst_mat.col def __len__(self): return len(self.nodes) def __getitem__(self, idx): return self.nodes[idx], self.labels[idx] class JointTrnData(data.Dataset): def __init__(self, dataset_list): self.batch_dataset_ids = [] self.batch_st_ed_list = [] self.dataset_list = dataset_list for dataset_id, dataset in enumerate(dataset_list): samp_num = len(dataset) // args.batch + (1 if len(dataset) % args.batch != 0 else 0) for j in range(samp_num): self.batch_dataset_ids.append(dataset_id) st = j * args.batch ed = min((j + 1) * args.batch, len(dataset)) self.batch_st_ed_list.append((st, ed)) def neg_sampling(self): for dataset in self.dataset_list: dataset.neg_sampling() def __len__(self): return len(self.batch_dataset_ids) def __getitem__(self, idx): st, ed = self.batch_st_ed_list[idx] dataset_id = self.batch_dataset_ids[idx] return 
*self.dataset_list[dataset_id][st: ed], dataset_id ================================================ FILE: node_classification/main.py ================================================ import torch as t from torch import nn import Utils.TimeLogger as logger from Utils.TimeLogger import log from params import args from model import Expert, Feat_Projector, Adj_Projector, AnyGraph from data_handler import MultiDataHandler, DataHandler import numpy as np import pickle import os import setproctitle import time from sklearn.metrics import f1_score class Exp: def __init__(self, multi_handler): self.multi_handler = multi_handler print(list(map(lambda x: x.data_name, multi_handler.trn_handlers))) for group_id, tst_handlers in enumerate(multi_handler.tst_handlers_group): print(f'Test group {group_id}', list(map(lambda x: x.data_name, tst_handlers))) self.metrics = dict() trn_mets = ['Loss', 'preLoss'] tst_mets = ['Recall', 'NDCG', 'Loss', 'preLoss'] mets = trn_mets + tst_mets for met in mets: if met in trn_mets: self.metrics['Train' + met] = list() if met in tst_mets: for i in range(len(self.multi_handler.tst_handlers_group)): self.metrics['Test' + str(i) + met] = list() def make_print(self, name, ep, reses, save, data_name=None): if data_name is None: ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name) else: ret = 'Epoch %d/%d, %s %s: ' % (ep, args.epoch, data_name, name) for metric in reses: val = reses[metric] ret += '%s = %.4f, ' % (metric, val) tem = name + metric if data_name is None else name + data_name + metric if save and tem in self.metrics: self.metrics[tem].append(val) ret = ret[:-2] + ' ' return ret def run(self): self.prepare_model() log('Model Prepared') stloc = 0 if args.load_model != None: self.load_model() stloc = len(self.metrics['TrainLoss']) * args.tst_epoch - (args.tst_epoch - 1) best_ndcg, best_ep = 0, -1 early_stop = False for ep in range(stloc, args.epoch): tst_flag = (ep % args.tst_epoch == 0) start_time = time.time() 
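The training loop repeatedly calls `assign_experts`, which scores every expert on every dataset via `Expert.attempt` (how well an expert's embeddings rank positive edges above sampled negatives) and routes each dataset to its best expert, optionally recalibrated toward less-trained experts. A hedged sketch of that routing logic (`route` is a hypothetical standalone function, not the repo's API):

```python
# Sketch of score-based expert routing: score = mean sigmoid margin between
# positive and negative edge predictions; optionally reweight scores so that
# experts with fewer training samples are favored (recalibration).
import torch as t

def route(expert_embeds, rows, cols, negs, trn_counts=None, reca_range=0.2):
    scores = []
    for e, emb in enumerate(expert_embeds):
        s = ((emb[rows] * emb[cols]).sum(-1)
             - (emb[rows] * emb[negs]).sum(-1)).sigmoid().mean().item()
        if trn_counts is not None:   # mirror the trn_count recalibration
            c = t.tensor(trn_counts, dtype=t.float)
            w = (1.0 - c / c.sum()) * reca_range + 1.0 - reca_range / 2
            s *= w[e].item()
        scores.append(s)
    return max(range(len(scores)), key=lambda e: scores[e])

embeds = [t.randn(10, 8) for _ in range(3)]
rows, cols, negs = t.tensor([0, 1]), t.tensor([2, 3]), t.tensor([4, 5])
best = route(embeds, rows, cols, negs, trn_counts=[1, 5, 1])
```

The chosen index plays the role of `self.assignment[dataset_id]`, which `summon` then uses to pick the expert and its optimizer.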
self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=True) reses = self.train_epoch() log(self.make_print('Train', ep, reses, tst_flag)) self.multi_handler.remake_initial_projections() end_time = time.time() print(f'NOTICE: {end_time-start_time}') if tst_flag: for handler_group_id in range(len(self.multi_handler.tst_handlers_group)): tst_handlers = self.multi_handler.tst_handlers_group[handler_group_id] self.model.assign_experts(tst_handlers, reca=False, log_assignment=True) recall, ndcg, tstnum = 0, 0, 0 for i, handler in enumerate(tst_handlers): reses = self.test_epoch(handler, i) log(self.make_print(handler.data_name, ep, reses, False)) recall += reses['Acc'] * reses['tstNum'] ndcg += reses['F1'] * reses['tstNum'] tstnum += reses['tstNum'] reses = {'Recall': recall / tstnum, 'NDCG': ndcg / tstnum} log(self.make_print('Test'+str(handler_group_id), ep, reses, tst_flag)) if reses['NDCG'] > best_ndcg: best_ndcg = reses['NDCG'] best_ep = ep self.save_history() print() for test_group_id in range(len(self.multi_handler.tst_handlers_group)): repeat_times = 10 overall_recall, overall_ndcg = np.zeros(repeat_times), np.zeros(repeat_times) overall_tstnum = 0 tst_handlers = self.multi_handler.tst_handlers_group[test_group_id] if args.assignment == 'one-graph-one-expert': self.model.assign_experts(tst_handlers, reca=False, log_assignment=True) for i, handler in enumerate(tst_handlers): mets = dict() for _ in range(repeat_times): handler.make_projectors() if not args.assignment == 'one-graph-one-expert': self.model.assign_experts([handler], reca=False, log_assignment=False) reses = self.test_epoch(handler, i if args.assignment == 'one-graph-one-expert' else 0) log(self.make_print('Test', args.epoch, reses, False)) for met in reses: if met not in mets: mets[met] = [] mets[met].append(reses[met]) tstnum = reses['tstNum'] tot_reses = dict() for met in reses: tem_arr = np.array(mets[met]) tot_reses[met + '_std'] = tem_arr.std() tot_reses[met
+ '_mean'] = tem_arr.mean() overall_recall += np.array(mets['Acc']) * tstnum overall_ndcg += np.array(mets['F1']) * tstnum overall_tstnum += tstnum log(self.make_print(f'Test', args.epoch, tot_reses, False, handler.data_name)) overall_recall /= overall_tstnum overall_ndcg /= overall_tstnum overall_res = dict() overall_res['Recall_mean'] = overall_recall.mean() overall_res['Recall_std'] = overall_recall.std() overall_res['NDCG_mean'] = overall_ndcg.mean() overall_res['NDCG_std'] = overall_ndcg.std() log(self.make_print('Overall Test', args.epoch, overall_res, False)) self.save_history() def print_model_size(self): total_params = 0 trainable_params = 0 non_trainable_params = 0 for param in self.model.parameters(): tem = np.prod(param.size()) total_params += tem if param.requires_grad: trainable_params += tem else: non_trainable_params += tem print(f'Total params: {total_params/1e6}') print(f'Trainable params: {trainable_params/1e6}') print(f'Non-trainable params: {non_trainable_params/1e6}') def prepare_model(self): self.model = AnyGraph() t.cuda.empty_cache() self.print_model_size() def train_epoch(self): self.model.train() trn_loader = self.multi_handler.joint_trn_loader trn_loader.dataset.neg_sampling() ep_loss, ep_preloss, ep_regloss = 0, 0, 0 steps = len(trn_loader) tot_samp_num = 0 counter = [0] * len(self.multi_handler.trn_handlers) reassign_steps = sum(list(map(lambda x: x.reproj_steps, self.multi_handler.trn_handlers))) for i, batch_data in enumerate(trn_loader): if args.epoch_max_step > 0 and i >= args.epoch_max_step: break ancs, poss, negs, dataset_id = batch_data ancs = ancs[0].long() poss = poss[0].long() negs = negs[0].long() dataset_id = dataset_id[0].long() tem_bar = self.multi_handler.trn_handlers[dataset_id].ratio_500_all if tem_bar < 1.0 and np.random.uniform() > tem_bar: steps -= 1 continue expert = self.model.summon(dataset_id)#.cuda() opt = self.model.summon_opt(dataset_id) # adj = self.multi_handler.trn_handlers[dataset_id].trn_input_adj feats 
= self.multi_handler.trn_handlers[dataset_id].projectors loss, loss_dict = expert.cal_loss((ancs, poss, negs), feats) opt.zero_grad() loss.backward() # nn.utils.clip_grad_norm_(expert.parameters(), max_norm=20, norm_type=2) opt.step() sample_num = ancs.shape[0] tot_samp_num += sample_num ep_loss += loss.item() * sample_num ep_preloss += loss_dict['preloss'].item() * sample_num ep_regloss += loss_dict['regloss'].item() log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True) counter[dataset_id] += 1 if (counter[dataset_id] + 1) % self.multi_handler.trn_handlers[dataset_id].reproj_steps == 0: # if args.proj_trn_steps > 0 and counter[dataset_id] >= args.proj_trn_steps: self.multi_handler.trn_handlers[dataset_id].make_projectors() if (i + 1) % reassign_steps == 0: self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=False) ret = dict() ret['Loss'] = ep_loss / tot_samp_num ret['preLoss'] = ep_preloss / tot_samp_num ret['regLoss'] = ep_regloss / steps t.cuda.empty_cache() return ret def make_trn_masks(self, numpy_usrs, csr_mat): trn_masks = csr_mat[numpy_usrs].tocoo() cand_size = trn_masks.shape[1] trn_masks = t.from_numpy(np.stack([trn_masks.row, trn_masks.col], axis=0)).long() return trn_masks, cand_size def test_loss_epoch(self, handler, dataset_id): with t.no_grad(): tst_loader = handler.tst_loss_loader self.model.eval() expert = self.model.summon(dataset_id)#.cuda() ep_loss, ep_preloss, ep_regloss = 0, 0, 0 steps = len(tst_loader) tot_samp_num = 0 for i, batch_data in enumerate(tst_loader): ancs, poss, negs = batch_data ancs = ancs.long() poss = poss.long() negs = negs.long() # adj = handler.tst_input_adj feats = handler.projectors loss, loss_dict = expert.cal_loss((ancs, poss, negs), feats) sample_num = ancs.shape[0] tot_samp_num += sample_num ep_loss += loss.item() * sample_num 
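`Expert.cal_loss` (with `args.loss == 'ce'`, used by the training step above) applies a max-shift to all logits before exponentiating: subtracting the maximum leaves the softmax unchanged but prevents `exp()` from overflowing. A self-contained sketch of that loss (hypothetical function name; in-batch negatives as in `anc_embeds @ neg_embeds.T`):

```python
# Sketch of the max-shifted cross-entropy ranking loss in Expert.cal_loss:
# shift all predictions by their maximum (pred_norm), then
# loss = -mean( pos - log(sum(exp(negs)) + exp(pos) + eps) ).
import torch as t

def ce_rank_loss(anc, pos, neg):
    pos_preds = (anc * pos).sum(-1)            # [batch]
    neg_preds = anc @ neg.T                    # [batch, batch] in-batch negatives
    shift = t.concat([pos_preds, neg_preds.view(-1)]).max()
    pos_preds, neg_preds = pos_preds - shift, neg_preds - shift
    neg_term = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log()
    return -(pos_preds - neg_term).mean()

anc, pos, neg = t.randn(4, 8), t.randn(4, 8), t.randn(4, 8)
loss = ce_rank_loss(anc, pos, neg)
```

Without the shift, logits of magnitude ~90+ would already overflow `exp()` in float32; with it, the loss stays finite for arbitrarily scaled embeddings.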
ep_preloss += loss_dict['preloss'].item() * sample_num ep_regloss += loss_dict['regloss'].item() log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True) ret = dict() ret['Loss'] = ep_loss / tot_samp_num ret['preLoss'] = ep_preloss / tot_samp_num ret['regLoss'] = ep_regloss / steps ret['tot_samp_num'] = tot_samp_num t.cuda.empty_cache() return ret def test_epoch(self, handler, dataset_id): with t.no_grad(): tst_loader = handler.tst_loader class_num = tst_loader.dataset.class_num self.model.eval() expert = self.model.summon(dataset_id) ep_acc, ep_tot = 0, 0 steps = len(tst_loader) for i, batch_data in enumerate(tst_loader): nodes, labels = list(map(lambda x: x.long().cuda(), batch_data)) feats = handler.projectors preds = expert.pred_for_node_test(nodes, class_num, feats, rerun_embed=False if i!=0 else True) if i == 0: all_preds, all_labels = preds, labels else: all_preds = t.concatenate([all_preds, preds]) all_labels = t.concatenate([all_labels, labels]) hit = (labels == preds).float().sum().item() ep_acc += hit ep_tot += labels.shape[0] log('Steps %d/%d: hit = %d, tot = %d ' % (i, steps, ep_acc, ep_tot), save=False, oneline=True) ret = dict() ret['Acc'] = ep_acc / ep_tot ret['F1'] = f1_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='macro') ret['tstNum'] = ep_tot t.cuda.empty_cache() return ret def calc_recall_ndcg(self, topLocs, tstLocs, batIds): assert topLocs.shape[0] == len(batIds) allRecall = allNdcg = 0 for i in range(len(batIds)): temTopLocs = list(topLocs[i]) temTstLocs = tstLocs[batIds[i]] tstNum = len(temTstLocs) maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))]) recall = dcg = 0 for val in temTstLocs: if val in temTopLocs: recall += 1 dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2)) recall = recall / tstNum ndcg = dcg / maxDcg 
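The ranking metrics in `calc_recall_ndcg` use the standard discounted gain 1/log2(rank + 2), with the ideal DCG computed over min(|test items|, K) positions. A small standalone sketch (hypothetical helper `recall_ndcg_at_k`, single user):

```python
# Sketch of per-user Recall@K / NDCG@K as computed in Exp.calc_recall_ndcg.
import numpy as np

def recall_ndcg_at_k(top_items, tst_items, k):
    top_items = list(top_items)[:k]
    # ideal DCG: all test items ranked at the top, capped at K positions
    max_dcg = sum(1.0 / np.log2(loc + 2) for loc in range(min(len(tst_items), k)))
    hits = dcg = 0.0
    for item in tst_items:
        if item in top_items:
            hits += 1
            dcg += 1.0 / np.log2(top_items.index(item) + 2)
    return hits / len(tst_items), dcg / max_dcg

# one hit (item 3 at rank 1) out of two test items
recall, ndcg = recall_ndcg_at_k([7, 3, 9], [3, 5], k=3)
```

Per-user values are then averaged over the batch, as the surrounding code does with `allRecall` and `allNdcg`.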
allRecall += recall allNdcg += ndcg return allRecall, allNdcg def save_history(self): if args.epoch == 0: return with open('../History/' + args.save_path + '.his', 'wb') as fs: pickle.dump(self.metrics, fs) content = { 'model': self.model, } t.save(content, '../Models/' + args.save_path + '.mod') log('Model Saved: %s' % args.save_path) def load_model(self): ckp = t.load('../Models/' + args.load_model + '.mod') self.model = ckp['model'] # self.model.set_initial_projection(self.handler.torch_adj) self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0) with open('../History/' + args.load_model + '.his', 'rb') as fs: self.metrics = pickle.load(fs) log('Model Loaded') if __name__ == '__main__': os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu if len(args.gpu.split(',')) == 2: args.devices = ['cuda:0', 'cuda:1'] elif len(args.gpu.split(',')) > 2: raise Exception('Devices should be less than 2') else: args.devices = ['cuda:0', 'cuda:0'] logger.saveDefault = True setproctitle.setproctitle('akaxia_AutoGraph') log('Start') datasets = dict() datasets['all'] = [ 'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech', 'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab', 'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1', ] datasets['ecommerce'] = [ 'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech' ] datasets['academic'] = [ 'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab' ] datasets['others'] = [ 'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 
        'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1'
    ]
    datasets['div1'] = [
        'products_tech', 'yelp2018', 'yelp_textfeat', 'products_home', 'steam_textfeat', 'amazon_textfeat',
        'amazon-book', 'citation-2019', 'citation-classic', 'pubmed', 'citeseer', 'ppa',
        'p2p-Gnutella06', 'soc-Epinions1', 'email-Enron',
    ]
    datasets['div2'] = [
        'Photo', 'Goodreads', 'Fitness', 'ml1m', 'ml10m', 'gowalla', 'arxiv', 'arxiv-ta', 'cora', 'CS',
        'collab', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'ddi',
        'web-Stanford', 'roadNet-PA',
    ]
    datasets['node'] = [
        'cora', 'arxiv', 'pubmed', 'home', 'tech'
    ]

    if args.dataset_setting in datasets.keys():
        trn_datasets = tst_datasets = datasets[args.dataset_setting]
    elif args.dataset_setting in datasets['all']:
        trn_datasets = tst_datasets = [args.dataset_setting]
    elif '+' in args.dataset_setting:
        idx = args.dataset_setting.index('+')
        trn_datasets = datasets[args.dataset_setting[:idx]]
        tst_datasets = datasets[args.dataset_setting[idx+1:]]
    elif '_in_' in args.dataset_setting:
        idx = args.dataset_setting.index('_in_')
        tst_datasets_1 = datasets[args.dataset_setting[:idx]]
        tst_datasets_2 = datasets[args.dataset_setting[idx+len('_in_'):]]
        tst_datasets = []
        for data in tst_datasets_1:
            if data in tst_datasets_2:
                tst_datasets.append(data)
        trn_datasets = tst_datasets
    # trn_datasets = tst_datasets = ['products_home']

    if '+' not in args.dataset_setting:
        handler = MultiDataHandler(trn_datasets, [tst_datasets])
    else:
        handler = MultiDataHandler(trn_datasets, [trn_datasets, tst_datasets])
    log('Load Data')

    exp = Exp(handler)
    exp.run()
    print(args.load_model, args.dataset_setting)

================================================
FILE: node_classification/model.py
================================================
import torch as t
from torch import nn
import torch.nn.functional as F
from params import args
import numpy as np
from Utils.TimeLogger import log
from torch.nn import MultiheadAttention
from time import time

init = nn.init.xavier_uniform_
# init = nn.init.normal_
uniformInit = nn.init.uniform_

class FeedForwardLayer(nn.Module):
    def __init__(self, in_feat, out_feat, bias=True, act=None):
        super(FeedForwardLayer, self).__init__()
        self.linear = nn.Linear(in_feat, out_feat, bias=bias)  # , dtype=t.bfloat16)
        if act == 'identity' or act is None:
            self.act = None
        elif act == 'leaky':
            self.act = nn.LeakyReLU(negative_slope=args.leaky)
        elif act == 'relu':
            self.act = nn.ReLU()
        elif act == 'relu6':
            self.act = nn.ReLU6()
        else:
            raise Exception('Error')

    def forward(self, embeds):
        if self.act is None:
            return self.linear(embeds)
        return self.act(self.linear(embeds))

class TopoEncoder(nn.Module):
    def __init__(self):
        super(TopoEncoder, self).__init__()
        self.layer_norm = nn.LayerNorm(args.latdim, elementwise_affine=False)

    def forward(self, adj, embeds, normed=False):
        with t.no_grad():
            if not normed:
                embeds = self.layer_norm(embeds)
            # embeds_list = []
            final_embeds = 0
            if args.gnn_layer == 0:
                final_embeds = embeds
                # embeds_list.append(embeds)
            for _ in range(args.gnn_layer):
                embeds = t.spmm(adj, embeds)
                final_embeds += embeds
                # embeds_list.append(embeds)
            embeds = final_embeds  # sum(embeds_list)
        return embeds

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(args.fc_layer)])
        self.layer_norms = nn.Sequential(*[nn.LayerNorm(args.latdim, elementwise_affine=True) for _ in range(args.fc_layer)])
        self.dropout = nn.Dropout(p=args.drop_rate)

    def forward(self, embeds):
        for i in range(args.fc_layer):
            embeds = self.layer_norms[i](self.dropout(self.dense_layers[i](embeds)) + embeds)
        return embeds

class GTLayer(nn.Module):
    def __init__(self):
        super(GTLayer, self).__init__()
        self.multi_head_attention = MultiheadAttention(args.latdim, args.head, dropout=0.1, bias=False)  # , dtype=t.bfloat16)
        self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(2)])  # bias=False
        self.layer_norm1 = nn.LayerNorm(args.latdim, elementwise_affine=True)  # , dtype=t.bfloat16)
        self.layer_norm2 = nn.LayerNorm(args.latdim, elementwise_affine=True)  # , dtype=t.bfloat16)
        self.fc_dropout = nn.Dropout(p=args.drop_rate)

    def _pick_anchors(self, embeds):
        perm = t.randperm(embeds.shape[0])
        anchors = perm[:args.anchor]
        return embeds[anchors]

    def forward(self, embeds):
        anchor_embeds = self._pick_anchors(embeds)
        _anchor_embeds, _ = self.multi_head_attention(anchor_embeds, embeds, embeds)
        anchor_embeds = _anchor_embeds + anchor_embeds
        _embeds, _ = self.multi_head_attention(embeds, anchor_embeds, anchor_embeds, need_weights=False)
        embeds = self.layer_norm1(_embeds + embeds)
        _embeds = self.fc_dropout(self.dense_layers(embeds))
        embeds = (self.layer_norm2(_embeds + embeds))
        return embeds

class GraphTransformer(nn.Module):
    def __init__(self):
        super(GraphTransformer, self).__init__()
        self.gt_layers = nn.Sequential(*[GTLayer() for i in range(args.gt_layer)])

    def forward(self, embeds):
        for i, layer in enumerate(self.gt_layers):
            embeds = layer(embeds) / args.scale_layer
        return embeds

class Feat_Projector(nn.Module):
    def __init__(self, feats):
        super(Feat_Projector, self).__init__()
        if args.proj_method == 'uniform':
            self.proj_embeds = self.uniform_proj(feats)
        elif args.proj_method == 'svd' or args.proj_method == 'both':
            self.proj_embeds = self.svd_proj(feats)
        elif args.proj_method == 'random':
            self.proj_embeds = self.random_proj(feats)
        elif args.proj_method == 'original':
            self.proj_embeds = feats
        self.proj_embeds = t.flip(self.proj_embeds, dims=[-1])
        self.proj_embeds = self.proj_embeds.detach()

    def svd_proj(self, feats):
        if args.latdim > feats.shape[0] or args.latdim > feats.shape[1]:
            dim = min(feats.shape[0], feats.shape[1])
            decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=dim, niter=args.niter)
            decom_feats = t.concat([decom_feats, t.zeros([decom_feats.shape[0], args.latdim - dim]).to(args.devices[0])], dim=1)
            s = t.concat([s, t.zeros(args.latdim - dim).to(args.devices[0])])
        else:
            decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=args.latdim, niter=args.niter)
        decom_feats = decom_feats @ t.diag(t.sqrt(s))
        return decom_feats.cpu()

    def uniform_proj(self, feats):
        projection = init(t.empty(args.featdim, args.latdim))
        return feats @ projection

    def random_proj(self, feats):
        projection = init(t.empty(feats.shape[0], args.latdim))
        return projection

    def forward(self):
        return self.proj_embeds

class Adj_Projector(nn.Module):
    def __init__(self, adj):
        super(Adj_Projector, self).__init__()
        if args.proj_method == 'adj_svd' or args.proj_method == 'both':
            self.proj_embeds = self.svd_proj(adj)
            self.proj_embeds = self.proj_embeds.detach()

    def svd_proj(self, adj):
        q = args.latdim
        if args.latdim > adj.shape[0] or args.latdim > adj.shape[1]:
            dim = min(adj.shape[0], adj.shape[1])
            svd_u, s, svd_v = t.svd_lowrank(adj, q=dim, niter=args.niter)
            svd_u = t.concat([svd_u, t.zeros([svd_u.shape[0], args.latdim - dim]).to(args.devices[0])], dim=1)
            svd_v = t.concat([svd_v, t.zeros([svd_v.shape[0], args.latdim - dim]).to(args.devices[0])], dim=1)
            s = t.concat([s, t.zeros(args.latdim - dim).to(args.devices[0])])
        else:
            svd_u, s, svd_v = t.svd_lowrank(adj, q=q, niter=args.niter)
        svd_u = svd_u @ t.diag(t.sqrt(s))
        svd_v = svd_v @ t.diag(t.sqrt(s))
        if adj.shape[0] != adj.shape[1]:
            projection = t.concat([svd_u, svd_v], dim=0)
        else:
            projection = svd_u + svd_v
        return projection.cpu()

    def forward(self):
        return self.proj_embeds

class Expert(nn.Module):
    def __init__(self):
        super(Expert, self).__init__()
        self.topo_encoder = TopoEncoder().to(args.devices[0])
        if args.nn == 'mlp':
            self.trainable_nn = MLP().to(args.devices[1])
        else:
            self.trainable_nn = GraphTransformer().to(args.devices[1])
        self.trn_count = 1

    def forward(self, projectors, pck_nodes=None):
        embeds = projectors.to(args.devices[1])
        if pck_nodes is not None:
            embeds = embeds[pck_nodes]
        embeds = self.trainable_nn(embeds)
        return embeds

    def pred_norm(self, pos_preds, neg_preds):
        pos_preds_num = pos_preds.shape[0]
        neg_preds_shape = neg_preds.shape
        preds = t.concat([pos_preds, neg_preds.view(-1)])
        preds = preds - preds.max()
        pos_preds = preds[:pos_preds_num]
        neg_preds = preds[pos_preds_num:].view(neg_preds_shape)
        return pos_preds, neg_preds

    def cal_loss(self, batch_data, projectors):
        ancs, poss, negs = list(map(lambda x: x.to(args.devices[1]), batch_data))
        self.trn_count += ancs.shape[0]
        pck_nodes = t.concat([ancs, poss, negs])
        final_embeds = self.forward(projectors, pck_nodes)
        # anc_embeds, pos_embeds, neg_embeds = final_embeds[ancs], final_embeds[poss], final_embeds[negs]
        anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds, [ancs.shape[0]] * 3)
        if final_embeds.isinf().any() or final_embeds.isnan().any():
            raise Exception('Final embedding fails')
        if args.loss == 'ce':
            pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T)
            if pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any():
                raise Exception('Preds fails')
            pos_loss = pos_preds
            neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log()
            pre_loss = -(pos_loss - neg_loss).mean()
        elif args.loss == 'bpr':
            pos_preds = (anc_embeds * pos_embeds).sum(-1)
            neg_preds = (anc_embeds * neg_embeds).sum(-1)
            pos_loss, neg_loss = pos_preds, neg_preds
            pre_loss = -((pos_preds - neg_preds).sigmoid() + 1e-10).log().mean()
        if t.isinf(pre_loss).any() or t.isnan(pre_loss).any():
            raise Exception('NaN or Inf')
        reg_loss = sum(list(map(lambda W: W.norm(2).square() * args.reg, self.parameters())))
        loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean()}
        return pre_loss + reg_loss, loss_dict

    def pred_for_test(self, batch_data, cand_size, projectors, rerun_embed=True):
        ancs, trn_mask = list(map(lambda x: x.to(args.devices[1]), batch_data))
        if rerun_embed:
            try:
                final_embeds = self.forward(projectors)
            except Exception:
                final_embeds_list = []
                div = args.batch * 3
                temlen = projectors.shape[0] // div
                for i in range(temlen):
                    st, ed = div * i, div * (i + 1)
                    tem_projectors = projectors[st: ed, :]
                    final_embeds_list.append(self.forward(tem_projectors))
                if temlen * div < projectors.shape[0]:
                    tem_projectors = projectors[temlen * div:, :]
                    final_embeds_list.append(self.forward(tem_projectors))
                final_embeds = t.concat(final_embeds_list, dim=0)
            self.final_embeds = final_embeds
        final_embeds = self.final_embeds
        anc_embeds = final_embeds[ancs]
        cand_embeds = final_embeds[-cand_size:]
        mask_mat = t.sparse.FloatTensor(trn_mask, t.ones(trn_mask.shape[1]).to(args.devices[1]), t.Size([ancs.shape[0], cand_size]))
        dense_mat = mask_mat.to_dense()
        all_preds = anc_embeds @ cand_embeds.T * (1 - dense_mat) - dense_mat * 1e8
        return all_preds

    def pred_for_node_test(self, nodes, cand_size, feats, rerun_embed=True):
        if rerun_embed:
            final_embeds = self.forward(feats)
            self.final_embeds = final_embeds
        final_embeds = self.final_embeds
        anc_embeds = final_embeds[nodes]
        class_embeds = final_embeds[-cand_size:]
        preds = anc_embeds @ class_embeds.T
        return t.argmax(preds, dim=-1)

    def attempt(self, topo_embeds, dataset):
        final_embeds = self.trainable_nn(topo_embeds)
        rows, cols, negs = list(map(lambda x: t.from_numpy(x).long().to(args.devices[1]), [dataset.ancs, dataset.poss, dataset.negs]))
        if rows.shape[0] > args.attempt_cache:
            random_perm = t.randperm(rows.shape[0], device=args.devices[0])
            pck_perm = random_perm[:args.attempt_cache]
            rows = rows[pck_perm]
            cols = cols[pck_perm]
            negs = negs[pck_perm]
        while True:
            try:
                row_embeds = final_embeds[rows]
                col_embeds = final_embeds[cols]
                neg_embeds = final_embeds[negs]
                score = ((row_embeds * col_embeds).sum(-1) - (row_embeds * neg_embeds).sum(-1)).sigmoid().mean().item()
                break
            except Exception:
                args.attempt_cache = args.attempt_cache // 2
                random_perm = t.randperm(rows.shape[0], device=args.devices[0])
                pck_perm = random_perm[:args.attempt_cache]
                rows = rows[pck_perm]
                cols = cols[pck_perm]
                negs = negs[pck_perm]
                t.cuda.empty_cache()
        return score

class AnyGraph(nn.Module):
    def __init__(self):
        super(AnyGraph, self).__init__()
        self.experts = nn.ModuleList([Expert() for _ in range(args.expert_num)]).cuda()
        self.opts = list(map(lambda expert: t.optim.Adam(expert.parameters(), lr=args.lr, weight_decay=0), self.experts))

    def one_graph_one_expert(self, handlers, log_assignment=False):
        self.assignment = [i for i in range(len(handlers))]
        if log_assignment:
            print('\n----------\nAssignment')
            for dataset_id, handler in enumerate(handlers):
                print(handler.data_name, f'{self.assignment[dataset_id]}')

    def store_history_assignment(self, assignment):
        pass

    def assign_experts(self, handlers, reca=True, log_assignment=False):
        if args.expert_num == 1:
            self.assignment = [0] * len(handlers)
            return
        if args.assignment == 'one-graph-one-expert':
            self.one_graph_one_expert(handlers, log_assignment)
            return
        if args.assignment not in ['top1', 'top2']:
            raise Exception('Unrecognized assigning methods')
        try:
            expert_scores = np.array(list(map(lambda expert: expert.trn_count, self.experts)))
            expert_scores = (1.0 - expert_scores / np.sum(expert_scores)) * args.reca_range + 1.0 - args.reca_range / 2
        except Exception:
            expert_scores = np.ones(len(self.experts))
        with t.no_grad():
            assignment = [list() for i in range(len(handlers))]
            for dataset_id, handler in enumerate(handlers):
                topo_embeds = handler.projectors.to(args.devices[1])
                for expert_id, expert in enumerate(self.experts):
                    expert = expert.to(args.devices[1])
                    score = expert.attempt(topo_embeds, handler.trn_loader.dataset)
                    if reca:
                        score *= expert_scores[expert_id]
                    assignment[dataset_id].append((expert_id, score))
                assignment[dataset_id].sort(key=lambda x: x[1], reverse=True)
                if args.assignment == 'top2':
                    if assignment[dataset_id][0][1] - assignment[dataset_id][1][1] < 0.05 and np.random.uniform() > 0.5:
                        tem = assignment[dataset_id][0]
                        assignment[dataset_id][0] = assignment[dataset_id][1]
                        assignment[dataset_id][1] = tem
        if log_assignment:
            print('\n----------\nAssignment')
            for dataset_id, handler in enumerate(handlers):
                out = ''
                for exp_idx in range(min(4, len(self.experts))):
                    out += f'({assignment[dataset_id][exp_idx][0]}, {assignment[dataset_id][exp_idx][1]}) '
                print(handler.data_name, out)
            print('----------\n')
        self.assignment = list(map(lambda x: x[0][0], assignment))

    def summon(self, dataset_id):
        return self.experts[self.assignment[dataset_id]]

    def summon_opt(self, dataset_id):
        return self.opts[self.assignment[dataset_id]]

================================================
FILE: node_classification/params.py
================================================
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Model Parameters')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--batch', default=4096, type=int, help='training batch size')
    parser.add_argument('--tst_batch', default=256, type=int, help='testing batch size (number of users)')
    parser.add_argument('--epoch', default=100, type=int, help='number of epochs')
    parser.add_argument('--save_path', default='tem', help='file name to save model and training record')
    parser.add_argument('--load_model', default=None, help='model name to load')
    parser.add_argument('--data', default='ml1m', type=str, help='name of dataset')
    parser.add_argument('--tst_epoch', default=1, type=int, help='number of epoch to test while training')
    parser.add_argument('--gpu', default='0', type=str, help='indicates which gpu to use')
    parser.add_argument('--topk', default=20, type=int, help='topk in evaluation')
    parser.add_argument('--epoch_max_step', default=-1, type=int, help='indicates the maximum number of steps in one epoch, -1 denotes full steps')
    parser.add_argument('--trn_mode', default='train-all', type=str, help='[fewshot, train-all]')
    parser.add_argument('--tst_mode', default='tst', type=str, help='[tst, val]')
    parser.add_argument('--eval_loss', default=True, type=bool, help='whether use CE loss to evaluate test performance')
    parser.add_argument('--ratio_fewshot_set', default=0.5, type=float, help='ratio of fewshot set')
    parser.add_argument('--shot', default=5, type=int, help='number of shots for each node')
    parser.add_argument('--tst_steps', default=-1, type=int, help='number of test steps, -1 indicates all')
    parser.add_argument('--reg', default=1e-7, type=float, help='weight decay regularizer')
    parser.add_argument('--latdim', default=512, type=int, help='latent dimensionality')
    parser.add_argument('--gnn_layer', default=3, type=int, help='number of gnn layers')
    parser.add_argument('--fc_layer', default=8, type=int, help='number of fully-connected layers')
    parser.add_argument('--gt_layer', default=2, type=int, help='number of graph transformer layers')
    parser.add_argument('--head', default=4, type=int, help='number of attention heads')
    parser.add_argument('--anchor', default=256, type=int, help='number of anchor nodes in the compressed graph transformer')
    parser.add_argument('--act', default='relu', type=str, help='activation function')
    parser.add_argument('--dataset_setting', default='training', type=str, help='which set of datasets to use')
    parser.add_argument('--assignment', default='top1', type=str, help='assigning method')
    parser.add_argument('--scale_layer', default=10, type=float, help='per-layer scale factor')
    parser.add_argument('--leaky', default=0.5, type=float, help='slope of leaky relu activation')
    parser.add_argument('--drop_rate', default=0.1, type=float, help='ratio of dropout')
    parser.add_argument('--reca_range', default=0.2, type=float, help='range of recalibration')
    parser.add_argument('--selfloop', default=0, type=int, help='indicating using self-loop or not')
    parser.add_argument('--niter', default=2, type=int, help='number of iterations in svd')
    parser.add_argument('--expert_num', default=8, type=int, help='number of experts')
    parser.add_argument('--loss', default='ce', type=str, help='loss function')
    parser.add_argument('--proj_method', default='both', type=str, help='feature projection method')
    parser.add_argument('--nn', default='mlp', type=str, help='what trainable network to use')
    parser.add_argument('--proj_trn_steps', default=100, type=int, help='number of training steps for one initial projection')
    parser.add_argument('--attempt_cache', default=10000000, type=int, help='maximum number of training samples used in one expert-routing attempt')
    return parser.parse_args()

args = parse_args()

================================================
FILE: params.py
================================================
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Model Parameters')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--batch', default=4096, type=int, help='training batch size')
    parser.add_argument('--tst_batch', default=256, type=int, help='testing batch size (number of users)')
    parser.add_argument('--epoch', default=100, type=int, help='number of epochs')
    parser.add_argument('--save_path', default='tem', help='file name to save model and training record')
    parser.add_argument('--load_model', default=None, help='model name to load')
    parser.add_argument('--tst_epoch', default=1, type=int, help='number of epoch to test while training')
    parser.add_argument('--gpu', default='0', type=str, help='indicates which gpu to use')
    parser.add_argument('--topk', default=20, type=int, help='topk in evaluation')
    parser.add_argument('--epoch_max_step', default=-1, type=int, help='indicates the maximum number of steps in one epoch, -1 denotes full steps')
    parser.add_argument('--trn_mode', default='train-all', type=str, help='[fewshot, train-all]')
    parser.add_argument('--tst_mode', default='tst', type=str, help='[tst, val]')
    parser.add_argument('--eval_loss', default=True, type=bool, help='whether use CE loss to evaluate test performance')
    parser.add_argument('--ratio_fewshot', default=0.1, type=float, help='ratio of fewshot set')
    parser.add_argument('--tst_steps', default=-1, type=int, help='number of test steps, -1 indicates all')
    parser.add_argument('--reg', default=1e-7, type=float, help='weight decay regularizer')
    parser.add_argument('--latdim', default=512, type=int, help='latent dimensionality')
    parser.add_argument('--gnn_layer', default=3, type=int, help='number of gnn layers')
    parser.add_argument('--fc_layer', default=8, type=int, help='number of fully-connected layers')
    parser.add_argument('--gt_layer', default=2, type=int, help='number of graph transformer layers')
    parser.add_argument('--head', default=4, type=int, help='number of attention heads')
    parser.add_argument('--anchor', default=256, type=int, help='number of anchor nodes in the compressed graph transformer')
    parser.add_argument('--act', default='relu', type=str, help='activation function')
    parser.add_argument('--dataset_setting', default='training', type=str, help='which set of datasets to use')
    parser.add_argument('--scale_layer', default=10, type=float, help='per-layer scale factor')
    parser.add_argument('--leaky', default=0.5, type=float, help='slope of leaky relu activation')
    parser.add_argument('--drop_rate', default=0.1, type=float, help='ratio of dropout')
    parser.add_argument('--reca_range', default=0.2, type=float, help='range of recalibration')
    parser.add_argument('--selfloop', default=0, type=int, help='indicating using self-loop or not')
    parser.add_argument('--niter', default=2, type=int, help='number of iterations in svd')
    parser.add_argument('--expert_num', default=8, type=int, help='number of experts')
    parser.add_argument('--loss', default='ce', type=str, help='loss function')
    parser.add_argument('--proj_method', default='both', type=str, help='feature projection method')
    parser.add_argument('--nn', default='mlp', type=str, help='what trainable network to use')
    parser.add_argument('--proj_trn_steps', default=100, type=int, help='number of training steps for one initial projection')
    parser.add_argument('--attempt_cache', default=10000000, type=int, help='maximum number of training samples used in one expert-routing attempt')
    return parser.parse_args()

args = parse_args()
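As a small, self-contained reference for the evaluation logic above, the Recall@K / NDCG@K computation in `calc_recall_ndcg` (node_classification/main.py) can be sketched for a single user as follows. This is a hypothetical standalone helper written for illustration, not part of the repository; `ranked` stands for a model's score-ordered candidate list and `relevant` for the held-out test items:

```python
import math

def recall_ndcg_at_k(ranked, relevant, k=20):
    """Hypothetical helper mirroring calc_recall_ndcg for one user.

    ranked:   candidate item ids ordered by predicted score (best first)
    relevant: set of held-out ground-truth item ids (non-empty)
    """
    topk = ranked[:k]
    # 0-based ranks at which relevant items were hit within the top-K
    hit_ranks = [r for r, item in enumerate(topk) if item in relevant]
    recall = len(hit_ranks) / len(relevant)
    # DCG with the 1 / log2(rank + 2) discount used by calc_recall_ndcg
    dcg = sum(1.0 / math.log2(r + 2) for r in hit_ranks)
    # ideal DCG: all relevant items placed at the top of the ranking
    max_dcg = sum(1.0 / math.log2(r + 2) for r in range(min(len(relevant), k)))
    return recall, dcg / max_dcg
```

`calc_recall_ndcg` computes the same two quantities per user and sums them over the batch; the trainer later divides by the number of test users to report the averaged metrics.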