Repository: HKUDS/AnyGraph
Branch: main
Commit: c2c5bfe103c4
Files: 15
Total size: 104.2 KB
Directory structure:
gitextract_t0oakqy3/
├── .gitignore
├── History/
│ ├── pretrain_link1.his
│ └── pretrain_link2.his
├── Models/
│ └── README.md
├── README.md
├── Utils/
│ └── TimeLogger.py
├── data_handler.py
├── main.py
├── model.py
├── node_classification/
│ ├── Utils/
│ │ └── TimeLogger.py
│ ├── data_handler.py
│ ├── main.py
│ ├── model.py
│ └── params.py
└── params.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
extract_code_structure.py
.DS_Store
================================================
FILE: Models/README.md
================================================
Download the pre-trained AnyGraph models from <a href='https://hkuhk-my.sharepoint.com/:u:/g/personal/lhaoxia_hku_hk/Efmm5TJm0B5EnmYzTqg8GWEB1loKzeIR5tcr3hPIOJDXXA?e=2wMgZC'>this link</a>.
================================================
FILE: README.md
================================================
<h1 align='center'>AnyGraph: Graph Foundation Model in the Wild</h1>
<div align='center'>
<a href='https://arxiv.org/pdf/2408.10700'><img src='https://img.shields.io/badge/Paper-PDF-green'></a>
<!-- <a href=''><img src='https://img.shields.io/badge/公众号-blue' /></a> -->
<!-- <a href=''><img src='https://img.shields.io/badge/CSDN-orange' /></a> -->
<img src="https://badges.pufler.dev/visits/hkuds/anygraph?style=flat-square&logo=github">
<img src='https://img.shields.io/github/stars/hkuds/anygraph?color=green&style=social' />
<a href='https://akaxlh.github.io/'>Lianghao Xia</a> and <a href='https://sites.google.com/view/chaoh/group-join-us'>Chao Huang</a>
**Introducing AnyGraph, a graph foundation model designed for zero-shot predictions across domains.**
<img src='imgs/article cover.png' />
</div>
**Objectives of AnyGraph:**
- **Structure Heterogeneity**: Addressing distribution shift in graph structural information.
- **Feature Heterogeneity**: Handling diverse feature representation spaces across graph datasets.
- **Fast Adaptation**: Efficiently adapting the model to new graph domains.
- **Scaling Law Emergence**: Performance scales with the amount of data and model parameters.
<br>
**Key Features of AnyGraph:**
- **Graph Mixture-of-Experts (MoE)**: Effectively addresses cross-domain heterogeneity using an array of expert models.
- **Lightweight Graph Expert Routing Mechanism**: Enables swift adaptation to new datasets and domains.
- **Adaptive and Efficient Graph Experts**: Custom-designed to handle graphs with a wide range of structural patterns and feature spaces.
- **Extensively Trained and Tested**: Exhibits strong generalizability over 38 diverse graph datasets, showcasing scaling laws and emergent capabilities.
<img src='imgs/framework_final.jpeg' />
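To make the expert-routing idea concrete, here is a minimal, hypothetical sketch of competence-score routing: each dataset is summarized by an embedding, and the expert whose prototype best matches it is selected. This is an illustration only, not the repository's actual routing implementation (AnyGraph's routing uses a self-supervised competence measurement; see the paper).

```python
import numpy as np

def route_to_expert(dataset_embedding, expert_prototypes):
    """Pick the expert whose prototype is most similar (cosine similarity)
    to the dataset summary embedding. Hypothetical illustration only."""
    d = dataset_embedding / np.linalg.norm(dataset_embedding)
    protos = expert_prototypes / np.linalg.norm(expert_prototypes, axis=1, keepdims=True)
    scores = protos @ d  # one competence score per expert
    return int(np.argmax(scores)), scores

rng = np.random.default_rng(0)
expert_id, scores = route_to_expert(rng.normal(size=8), rng.normal(size=(4, 8)))
```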
## News
- [x] 2024/09/18 Fixed minor bugs in loading/saving paths, class names, etc.
- [x] 2024/09/11 Datasets are available now!
- [x] 2024/09/11 Pre-trained models are available on huggingface now!
- [x] 2024/08/20 Paper is released.
- [x] 2024/08/20 Code is released.
## Environment Setup
Download the data files from <a href='https://huggingface.co/datasets/hkuds/AnyGraph_datasets'>this link</a>, and fill in your own data-storage directories in the `get_data_files(self)` method of the `DataHandler` class in `data_handler.py`.
Download the pre-trained AnyGraph models from <a href='https://huggingface.co/hkuds/AnyGraph/'>Hugging Face</a> or <a href='https://hkuhk-my.sharepoint.com/:u:/g/personal/lhaoxia_hku_hk/Efmm5TJm0B5EnmYzTqg8GWEB1loKzeIR5tcr3hPIOJDXXA?e=2wMgZC'>OneDrive</a>, and put them into `Models/`.
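For reference, `get_data_files` builds all file paths from one hard-coded directory prefix; the sketch below is a hypothetical stand-in showing the file layout it expects (`base_dir` is a placeholder for your own storage directory, and the file names match those loaded in `data_handler.py`):

```python
import os

# Hypothetical helper mirroring the layout DataHandler.get_data_files expects;
# replace base_dir with the directory where you unpacked the datasets.
def get_data_files(data_name, base_dir='/path/to/AnyGraph_datasets'):
    predir = os.path.join(base_dir, data_name)
    files = {
        'trn': os.path.join(predir, 'trn_mat.pkl'),
        'val': os.path.join(predir, 'val_mat.pkl'),
        'tst': os.path.join(predir, 'tst_mat.pkl'),
    }
    # node features are optional: some datasets ship without feats.pkl
    feat_file = os.path.join(predir, 'feats.pkl')
    files['feat'] = feat_file if os.path.exists(feat_file) else None
    return files

files = get_data_files('cora')
```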
**Packages**: Our experiments were conducted with the following package versions:
* python==3.10.13
* torch==1.13.0
* numpy==1.23.4
* scipy==1.9.3
**Device Requirements**: Training and testing AnyGraph requires only one GPU with 24GB of memory (e.g. RTX 3090, 4090). Larger input graphs may require devices with more memory.
## Code Structure
Here is a brief overview of the code structure. The explanation for each file or directory follows the `## ... ##` markers.
```
./
├── .gitignore
├── README.md
├── data_handler.py ## load and process data
├── model.py ## implementation of the model
├── params.py ## hyperparameters
├── main.py ## main file for pretraining and link prediction
├── History/ ## training and testing logs
│   ├── pretrain_link1.his
│   └── pretrain_link2.his
├── Models/ ## pre-trained models
│   └── README.md
├── Utils/ ## utility functions
│   └── TimeLogger.py
├── imgs/ ## images used in the readme
│   ├── ablation.png
│   ├── article cover.png
│   ├── datasets.png
│   ├── framework_final.jpeg
│   ├── framework_final.pdf
│   ├── framework_final.png
│   ├── overall_performance1.png
│   ├── overall_performance2.png
│   ├── routing.png
│   ├── scaling_law.png
│   ├── training_time.png
│   └── tuning_steps.png
└── node_classification/ ## test code for node classification
    ├── data_handler.py
    ├── model.py
    ├── params.py
    ├── main.py
    └── Utils/
        └── TimeLogger.py
```
## Usage
To reproduce the test performance reported in the paper, run the following commands:
```
# Test on Link2 and Link1 data, respectively
python main.py --load pretrain_link1 --epoch 0 --dataset link2
python main.py --load pretrain_link2 --epoch 0 --dataset link1
# Test on the Ecommerce datasets in the Link2 and Link1 groups, respectively.
# Testing on the Academic and Others datasets is conducted similarly.
python main.py --load pretrain_link1 --epoch 0 --dataset ecommerce_in_link2
python main.py --load pretrain_link2 --epoch 0 --dataset ecommerce_in_link1
# Test the performance for node classification datasets
cd ./node_classification
python main.py --load pretrain_link2 --epoch 0 --dataset node
```
To re-train the two models by yourself, run:
```
python main.py --dataset link2+link1 --save pretrain_link2
python main.py --dataset link1+link2 --save pretrain_link1
```
## Datasets
<img src='imgs/datasets.png' />
The statistics for the experimental datasets are presented in the table above. We categorize them into the distinct groups below. Note that Link1 and Link2 include datasets from different sources, and the datasets do not share the same feature spaces. This separation ensures a robust evaluation of true zero-shot performance in graph prediction tasks.
| Group | Included Datasets |
| ----- | ----- |
| Link1 | Products-tech, Yelp2018, Yelp-textfeat, Products-home, Steam-text, Amazon-text, Amazon-book, Citation-2019, Citation-20Century, Pubmed-link, Citeseer, OGB-PPA, P2P-Gnutella06, Soc-Epinions1, Email-Enron |
| Link2 | Photo, Goodreads, Fitness, Movielens-1M, Movielens-10M, Gowalla, Arxiv, Arxiv-t, Cora, CS, OGB-Collab, Proteins-0, Proteins-1, Proteins-2, Proteins-3, OGB-DDI, Web-Stanford, RoadNet-PA |
| Ecommerce | Products-tech, Yelp2018, Yelp-textfeat, Products-home, Steam-text, Amazon-text, Amazon-book, Photo, Goodreads, Fitness, Movielens-1M, Movielens-10M, Gowalla |
| Academic | Citation-2019, Citation-20Century, Pubmed-link, Citeseer, OGB-PPA, Arxiv, Arxiv-t, Cora, CS, OGB-Collab |
| Others | P2P-Gnutella06, Soc-Epinions1, Email-Enron, Proteins-0, Proteins-1, Proteins-2, Proteins-3, OGB-DDI, Web-Stanford, RoadNet-PA |
| Node | Cora, Arxiv, Pubmed, Home, Tech |
## Experiments
### Model Pre-Training Curves
We present the training logs with respect to epochs below. Each figure contains two curves, corresponding to two repeated pre-training runs.
- pretrain_link1
<img src='imgs/link1_loss_curve.png' width=32%/>
<img src='imgs/link1_fullshot_ndcg_curve.png' width=32%/>
<img src='imgs/link1_zeroshot_ndcg_curve.png' width=32%/>
- pretrain_link2
<img src='imgs/link2_loss_curve.png' width=32%/>
<img src='imgs/link2_fullshot_ndcg_curve.png' width=32%/>
<img src='imgs/link2_zeroshot_ndcg_curve.png' width=32%/>
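The Recall/NDCG curves above use the standard top-k ranking metrics, computed per test node as in `calc_recall_ndcg` of `main.py`. A minimal self-contained version, with hypothetical item IDs:

```python
import numpy as np

def recall_ndcg_at_k(top_items, true_items, k):
    """Recall@k and NDCG@k for a single test node, mirroring
    calc_recall_ndcg in main.py."""
    top_items = list(top_items[:k])
    # ideal DCG: all relevant items ranked at the top
    max_dcg = sum(1.0 / np.log2(rank + 2) for rank in range(min(len(true_items), k)))
    hits, dcg = 0, 0.0
    for item in true_items:
        if item in top_items:
            hits += 1
            dcg += 1.0 / np.log2(top_items.index(item) + 2)
    return hits / len(true_items), dcg / max_dcg

# hypothetical ranking: one of the two relevant items (7) is ranked 2nd
recall, ndcg = recall_ndcg_at_k([3, 7, 1, 9], [7, 2], k=4)  # recall = 0.5
```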
### Overall Performance Comparison
- Comparison with few-shot end-to-end models and pre-training & fine-tuning methods.

- Comparison with zero-shot graph foundation models.
<img src='imgs/overall_performance2.png' width=60%/>
### Scaling Law of AnyGraph
We explore the scaling law of AnyGraph by evaluating 1) model performance vs. the number of model parameters, and 2) model performance vs. the amount of training data.
Below are the evaluation results on:
- all datasets across domains (a)
- academic datasets (b)
- ecommerce datasets (c)
- other datasets (d)
In each subfigure, we show
- zero-shot performance on unseen datasets w.r.t. the amount of model parameters (left)
- full-shot performance on training datasets w.r.t. the amount of model parameters (middle)
- zero-shot performance w.r.t. the amount of training data (right)

The results highlight the following key observations (see Sec. 4.3 for details):
- Generalizability of AnyGraph Follows the Scaling Law.
- Emergent Abilities of AnyGraph.
- Insufficient Training Data May Bring Bias.
### Ablation Study
The ablation study investigates the impact of the following modules:
- The overall MoE architecture
- Frequency regularization in the expert routing mechanism
- Graph augmentation in the learning process
- The utilization of (heterogeneous) node features from different datasets
<img src='imgs/ablation.png' width=60% />
### Expert Routing Mechanism
We visualize the competence scores between datasets and experts, given by the routing algorithm of AnyGraph.
The resulting scores below reveal the underlying relatedness between different datasets, illustrating the effectiveness of the routing mechanism. (see Sec. 4.5 for details)
<img src='imgs/routing.png' width=60% />
### Fast Adaptation of AnyGraph
We study the fast adaptation abilities of AnyGraph from two aspects:
- When fine-tuned on unseen datasets, AnyGraph achieves better performance in fewer training steps. (Fig. 6 below)
- The training time of AnyGraph is comparable to that of other methods. (Table 3 below)
<img src='imgs/tuning_steps.png' width=60% />
<img src='imgs/training_time.png' width=60% />
## Citation
If you find our work useful, please consider citing our paper:
```
@article{xia2024anygraph,
  title={AnyGraph: Graph Foundation Model in the Wild},
  author={Xia, Lianghao and Huang, Chao},
  journal={arXiv preprint arXiv:2408.10700},
  year={2024}
}
```
================================================
FILE: Utils/TimeLogger.py
================================================
import datetime
logmsg = ''
timemark = dict()
saveDefault = False
def log(msg, save=None, oneline=False):
    global logmsg
    global saveDefault
    time = datetime.datetime.now()
    tem = '%s: %s' % (time, msg)
    if save != None:
        if save:
            logmsg += tem + '\n'
    elif saveDefault:
        logmsg += tem + '\n'
    if oneline:
        print(tem, end='\r')
    else:
        print(tem)

def marktime(marker):
    global timemark
    timemark[marker] = datetime.datetime.now()

if __name__ == '__main__':
    log('')
================================================
FILE: data_handler.py
================================================
import pickle
import numpy as np
from scipy.sparse import csr_matrix, coo_matrix, dok_matrix
from params import args
import scipy.sparse as sp
from Utils.TimeLogger import log
import torch as t
import torch.utils.data as data
import torch_geometric.transforms as T
from model import Feat_Projector, Adj_Projector, TopoEncoder
import os
class MultiDataHandler:
    def __init__(self, trn_datasets, tst_datasets_group):
        all_datasets = trn_datasets
        all_tst_datasets = []
        for tst_datasets in tst_datasets_group:
            all_datasets = all_datasets + tst_datasets
            all_tst_datasets += tst_datasets
        all_datasets = list(set(all_datasets))
        all_datasets.sort()
        self.trn_handlers = []
        self.tst_handlers_group = [list() for i in range(len(tst_datasets_group))]
        for data_name in all_datasets:
            trn_flag = data_name in trn_datasets
            tst_flag = data_name in all_tst_datasets
            handler = DataHandler(data_name)
            if trn_flag:
                self.trn_handlers.append(handler)
            if tst_flag:
                for i in range(len(tst_datasets_group)):
                    if data_name in tst_datasets_group[i]:
                        self.tst_handlers_group[i].append(handler)
        self.make_joint_trn_loader()

    def make_joint_trn_loader(self):
        loader_datasets = []
        for trn_handler in self.trn_handlers:
            tem_dataset = trn_handler.trn_loader.dataset
            loader_datasets.append(tem_dataset)
        joint_dataset = JointTrnData(loader_datasets)
        self.joint_trn_loader = data.DataLoader(joint_dataset, batch_size=1, shuffle=True, num_workers=0)

    def remake_initial_projections(self):
        for i in range(len(self.trn_handlers)):
            trn_handler = self.trn_handlers[i]
            trn_handler.make_projectors()
class DataHandler:
    def __init__(self, data_name):
        self.data_name = data_name
        self.get_data_files()
        log(f'Loading dataset {data_name}')
        self.topo_encoder = TopoEncoder()
        self.load_data()

    def get_data_files(self):
        predir = f'/home/user_name/data/zero-shot datasets/{self.data_name}/'
        if os.path.exists(predir + 'feats.pkl'):
            self.feat_file = predir + 'feats.pkl'
        else:
            self.feat_file = None
        self.trnfile = predir + 'trn_mat.pkl'
        self.tstfile = predir + 'tst_mat.pkl'
        self.fewshotfile = predir + 'partial_mat_{shot}.pkl'.format(shot=args.ratio_fewshot)
        self.valfile = predir + 'val_mat.pkl'

    def load_one_file(self, filename):
        with open(filename, 'rb') as fs:
            ret = (pickle.load(fs) != 0).astype(np.float32)
        if type(ret) != coo_matrix:
            ret = sp.coo_matrix(ret)
        return ret

    def load_feats(self, filename):
        try:
            with open(filename, 'rb') as fs:
                feats = pickle.load(fs)
        except Exception as e:
            print(filename + str(e))
            exit()
        return feats

    def normalize_adj(self, mat, log=False):
        degree = np.array(mat.sum(axis=-1))
        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
        if mat.shape[0] == mat.shape[1]:
            return mat.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo()
        else:
            tem = d_inv_sqrt_mat.dot(mat)
            col_degree = np.array(mat.sum(axis=0))
            d_inv_sqrt = np.reshape(np.power(col_degree, -0.5), [-1])
            d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
            d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
            return tem.dot(d_inv_sqrt_mat).tocoo()

    def unique_numpy(self, row, col):
        hash_vals = row * args.node_num + col
        hash_vals = np.unique(hash_vals).astype(np.int64)
        col = hash_vals % args.node_num
        row = (hash_vals - col).astype(np.int64) // args.node_num
        return row, col
    def make_torch_adj(self, mat, unidirectional_for_asym=False):
        if mat.shape[0] == mat.shape[1]:
            _row = mat.row
            _col = mat.col
            row = np.concatenate([_row, _col]).astype(np.int64)
            col = np.concatenate([_col, _row]).astype(np.int64)
            row, col = self.unique_numpy(row, col)
            data = np.ones_like(row)
            mat = coo_matrix((data, (row, col)), mat.shape)
            if args.selfloop == 1:
                mat = (mat + sp.eye(mat.shape[0])) * 1.0
            normed_asym_mat = self.normalize_adj(mat)
            row = t.from_numpy(normed_asym_mat.row).long()
            col = t.from_numpy(normed_asym_mat.col).long()
            idxs = t.stack([row, col], dim=0)
            vals = t.from_numpy(normed_asym_mat.data).float()
            shape = t.Size(normed_asym_mat.shape)
            asym_adj = t.sparse.FloatTensor(idxs, vals, shape)
            return asym_adj
        elif unidirectional_for_asym:
            mat = (mat != 0) * 1.0
            mat = self.normalize_adj(mat, log=True)
            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
            vals = t.from_numpy(mat.data.astype(np.float32))
            shape = t.Size(mat.shape)
            return t.sparse.FloatTensor(idxs, vals, shape)
        else:
            # make ui adj
            a = sp.csr_matrix((args.user_num, args.user_num))
            b = sp.csr_matrix((args.item_num, args.item_num))
            mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
            mat = (mat != 0) * 1.0
            if args.selfloop == 1:
                mat = (mat + sp.eye(mat.shape[0])) * 1.0
            mat = self.normalize_adj(mat)
            # make cuda tensor
            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
            vals = t.from_numpy(mat.data.astype(np.float32))
            shape = t.Size(mat.shape)
            return t.sparse.FloatTensor(idxs, vals, shape)
    def load_data(self):
        tst_mat = self.load_one_file(self.tstfile)
        val_mat = self.load_one_file(self.valfile)
        trn_mat = self.load_one_file(self.trnfile)
        fewshot_mat = self.load_one_file(self.fewshotfile)
        if self.feat_file is not None:
            self.feats = t.from_numpy(self.load_feats(self.feat_file)).float()
            args.featdim = self.feats.shape[1]
        else:
            self.feats = None
            args.featdim = args.latdim
        if trn_mat.shape[0] != trn_mat.shape[1]:
            args.user_num, args.item_num = trn_mat.shape
            args.node_num = args.user_num + args.item_num
            print('Dataset: {data_name}, User num: {user_num}, Item num: {item_num}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, user_num=args.user_num, item_num=args.item_num, node_num=args.node_num, edge_num=trn_mat.nnz))
        else:
            args.node_num = trn_mat.shape[0]
            print('Dataset: {data_name}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, node_num=args.node_num, edge_num=trn_mat.nnz+val_mat.nnz+tst_mat.nnz))
        if args.tst_mode == 'tst':
            tst_data = TstData(tst_mat, trn_mat)
            self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
            self.tst_input_adj = self.make_torch_adj(trn_mat)
        elif args.tst_mode == 'val':
            tst_data = TstData(val_mat, trn_mat)
            self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
            self.tst_input_adj = self.make_torch_adj(fewshot_mat)
        else:
            raise Exception('Specify proper test mode')
        if args.trn_mode == 'fewshot':
            self.trn_mat = fewshot_mat
            trn_data = TrnData(self.trn_mat)
            self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0)
            self.trn_input_adj = self.make_torch_adj(fewshot_mat)
            if args.tst_mode == 'val':
                self.trn_input_adj = self.tst_input_adj
            else:
                self.trn_input_adj = self.make_torch_adj(fewshot_mat)
        elif args.trn_mode == 'train-all':
            self.trn_mat = trn_mat
            trn_data = TrnData(self.trn_mat)
            self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0)
            if args.tst_mode == 'tst':
                self.trn_input_adj = self.tst_input_adj
            else:
                self.trn_input_adj = self.make_torch_adj(trn_mat)
        else:
            raise Exception('Specify proper train mode')
        if self.trn_mat.shape[0] == self.trn_mat.shape[1]:
            self.asym_adj = self.trn_input_adj
        else:
            self.asym_adj = self.make_torch_adj(self.trn_mat, unidirectional_for_asym=True)
        self.make_projectors()
        self.reproj_steps = max(len(self.trn_loader.dataset) // (10 * args.batch), args.proj_trn_steps)
        self.ratio_500_all = 500 / len(self.trn_loader)
    def make_projectors(self):
        with t.no_grad():
            projectors = []
            if args.proj_method == 'adj_svd' or args.proj_method == 'both':
                tem = self.asym_adj.to(args.devices[0])
                projectors = [Adj_Projector(tem)]
            if self.feats is not None and args.proj_method != 'adj_svd':
                tem = self.feats.to(args.devices[0])
                projectors.append(Feat_Projector(tem))
            assert args.tst_mode == 'tst' and args.trn_mode == 'train-all' or args.tst_mode == 'val' and args.trn_mode == 'fewshot'
            feats = projectors[0]()
            if len(projectors) == 2:
                feats2 = projectors[1]()
                feats = feats + feats2
            try:
                self.projectors = self.topo_encoder(self.trn_input_adj.to(args.devices[0]), feats.to(args.devices[0])).detach().cpu()
            except Exception:
                print(f'{self.data_name} memory overflow')
                mean, std = feats.mean(dim=-1, keepdim=True), feats.std(dim=-1, keepdim=True)
                tem_adj = self.trn_input_adj.to(args.devices[0])
                mem_cache = 256
                projectors_list = []
                for i in range(feats.shape[1] // mem_cache):
                    st, ed = i * mem_cache, (i + 1) * mem_cache
                    tem_feats = (feats[:, st:ed] - mean) / (std + 1e-8)
                    tem_feats = self.topo_encoder(tem_adj, tem_feats.to(args.devices[0]), normed=True).detach().cpu()
                    projectors_list.append(tem_feats)
                self.projectors = t.concat(projectors_list, dim=-1)
            t.cuda.empty_cache()
class TstData(data.Dataset):
    def __init__(self, coomat, trn_mat):
        self.csrmat = (trn_mat.tocsr() != 0) * 1.0
        tstLocs = [None] * coomat.shape[0]
        tst_nodes = set()
        for i in range(len(coomat.data)):
            row = coomat.row[i]
            col = coomat.col[i]
            if tstLocs[row] is None:
                tstLocs[row] = list()
            tstLocs[row].append(col)
            tst_nodes.add(row)
        tst_nodes = np.array(list(tst_nodes))
        self.tst_nodes = tst_nodes
        self.tstLocs = tstLocs

    def __len__(self):
        return len(self.tst_nodes)

    def __getitem__(self, idx):
        return self.tst_nodes[idx]
class TrnData(data.Dataset):
    def __init__(self, coomat):
        self.ancs, self.poss = coomat.row, coomat.col
        self.negs = np.zeros(len(self.ancs)).astype(np.int32)
        self.cand_num = coomat.shape[1]
        self.neg_shift = 0 if coomat.shape[0] == coomat.shape[1] else coomat.shape[0]
        self.poss = coomat.col + self.neg_shift
        self.neg_sampling()

    def neg_sampling(self):
        self.negs = np.random.randint(self.cand_num + self.neg_shift, size=self.poss.shape[0])

    def __len__(self):
        return len(self.ancs)

    def __getitem__(self, idx):
        return self.ancs[idx], self.poss[idx], self.negs[idx]
class JointTrnData(data.Dataset):
    def __init__(self, dataset_list):
        self.batch_dataset_ids = []
        self.batch_st_ed_list = []
        self.dataset_list = dataset_list
        for dataset_id, dataset in enumerate(dataset_list):
            samp_num = len(dataset) // args.batch + (1 if len(dataset) % args.batch != 0 else 0)
            for j in range(samp_num):
                self.batch_dataset_ids.append(dataset_id)
                st = j * args.batch
                ed = min((j + 1) * args.batch, len(dataset))
                self.batch_st_ed_list.append((st, ed))

    def neg_sampling(self):
        for dataset in self.dataset_list:
            dataset.neg_sampling()

    def __len__(self):
        return len(self.batch_dataset_ids)

    def __getitem__(self, idx):
        st, ed = self.batch_st_ed_list[idx]
        dataset_id = self.batch_dataset_ids[idx]
        return *self.dataset_list[dataset_id][st: ed], dataset_id
================================================
FILE: main.py
================================================
import torch as t
from torch import nn
import Utils.TimeLogger as logger
from Utils.TimeLogger import log
from params import args
from model import Expert, Feat_Projector, Adj_Projector, AnyGraph
from data_handler import MultiDataHandler, DataHandler
import numpy as np
import pickle
import os
import setproctitle
import time
class Exp:
    def __init__(self, multi_handler):
        self.multi_handler = multi_handler
        print(list(map(lambda x: x.data_name, multi_handler.trn_handlers)))
        for group_id, tst_handlers in enumerate(multi_handler.tst_handlers_group):
            print(f'Test group {group_id}', list(map(lambda x: x.data_name, tst_handlers)))
        self.metrics = dict()
        trn_mets = ['Loss', 'preLoss']
        tst_mets = ['Recall', 'NDCG', 'Loss', 'preLoss']
        mets = trn_mets + tst_mets
        for met in mets:
            if met in trn_mets:
                self.metrics['Train' + met] = list()
            if met in tst_mets:
                for i in range(len(self.multi_handler.tst_handlers_group)):
                    self.metrics['Test' + str(i) + met] = list()

    def make_print(self, name, ep, reses, save, data_name=None):
        if data_name is None:
            ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)
        else:
            ret = 'Epoch %d/%d, %s %s: ' % (ep, args.epoch, data_name, name)
        for metric in reses:
            val = reses[metric]
            ret += '%s = %.4f, ' % (metric, val)
            tem = name + metric if data_name is None else name + data_name + metric
            if save and tem in self.metrics:
                self.metrics[tem].append(val)
        ret = ret[:-2] + ' '
        return ret
    def run(self):
        self.prepare_model()
        log('Model Prepared')
        stloc = 0
        if args.load_model != None:
            self.load_model()
            stloc = len(self.metrics['TrainLoss']) * args.tst_epoch - (args.tst_epoch - 1)
        best_ndcg, best_ep = 0, -1
        for ep in range(stloc, args.epoch):
            tst_flag = (ep % args.tst_epoch == 0)
            start_time = time.time()
            self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=True)
            reses = self.train_epoch()
            log(self.make_print('Train', ep, reses, tst_flag))
            self.multi_handler.remake_initial_projections()
            end_time = time.time()
            print(f'NOTICE: {end_time-start_time}')
            if tst_flag:
                for handler_group_id in range(len(self.multi_handler.tst_handlers_group)):
                    tst_handlers = self.multi_handler.tst_handlers_group[handler_group_id]
                    self.model.assign_experts(tst_handlers, reca=False, log_assignment=True)
                    recall, ndcg, tstnum = 0, 0, 0
                    for i, handler in enumerate(tst_handlers):
                        reses = self.test_epoch(handler, i)
                        # log(self.make_print(f'{handler.data_name}', ep, reses, False))
                        recall += reses['Recall'] * reses['tstNum']
                        ndcg += reses['NDCG'] * reses['tstNum']
                        tstnum += reses['tstNum']
                    reses = {'Recall': recall / tstnum, 'NDCG': ndcg / tstnum}
                    log(self.make_print('Test' + str(handler_group_id), ep, reses, tst_flag))
                    if reses['NDCG'] > best_ndcg:
                        best_ndcg = reses['NDCG']
                        best_ep = ep
                self.save_history()
            print()
        for test_group_id in range(len(self.multi_handler.tst_handlers_group)):
            repeat_times = 5
            overall_recall, overall_ndcg = np.zeros(repeat_times), np.zeros(repeat_times)
            overall_tstnum = 0
            tst_handlers = self.multi_handler.tst_handlers_group[test_group_id]
            for i, handler in enumerate(tst_handlers):
                for topk in [args.topk]:
                    args.topk = topk
                    mets = dict()
                    for _ in range(repeat_times):
                        handler.make_projectors()
                        self.model.assign_experts([handler], reca=False, log_assignment=False)
                        reses = self.test_epoch(handler, 0)
                        for met in reses:
                            if met not in mets:
                                mets[met] = []
                            mets[met].append(reses[met])
                        tstnum = reses['tstNum']
                    tot_reses = dict()
                    for met in reses:
                        tem_arr = np.array(mets[met])
                        tot_reses[met + '_std'] = tem_arr.std()
                        tot_reses[met + '_mean'] = tem_arr.mean()
                    if topk == args.topk:
                        overall_recall += np.array(mets['Recall']) * tstnum
                        overall_ndcg += np.array(mets['NDCG']) * tstnum
                        overall_tstnum += tstnum
                    log(self.make_print(f'Test Top-{topk}', args.epoch, tot_reses, False, handler.data_name))
            overall_recall /= overall_tstnum
            overall_ndcg /= overall_tstnum
            overall_res = dict()
            overall_res['Recall_mean'] = overall_recall.mean()
            overall_res['Recall_std'] = overall_recall.std()
            overall_res['NDCG_mean'] = overall_ndcg.mean()
            overall_res['NDCG_std'] = overall_ndcg.std()
            log(self.make_print('Overall Test', args.epoch, overall_res, False))
        self.save_history()
    def print_model_size(self):
        total_params = 0
        trainable_params = 0
        non_trainable_params = 0
        for param in self.model.parameters():
            tem = np.prod(param.size())
            total_params += tem
            if param.requires_grad:
                trainable_params += tem
            else:
                non_trainable_params += tem
        print(f'Total params: {total_params/1e6}')
        print(f'Trainable params: {trainable_params/1e6}')
        print(f'Non-trainable params: {non_trainable_params/1e6}')

    def prepare_model(self):
        self.model = AnyGraph()
        t.cuda.empty_cache()
        self.print_model_size()
    def train_epoch(self):
        self.model.train()
        trn_loader = self.multi_handler.joint_trn_loader
        trn_loader.dataset.neg_sampling()
        ep_loss, ep_preloss, ep_regloss = 0, 0, 0
        steps = len(trn_loader)
        tot_samp_num = 0
        counter = [0] * len(self.multi_handler.trn_handlers)
        reassign_steps = sum(list(map(lambda x: x.reproj_steps, self.multi_handler.trn_handlers)))
        for i, batch_data in enumerate(trn_loader):
            if args.epoch_max_step > 0 and i >= args.epoch_max_step:
                break
            ancs, poss, negs, dataset_id = batch_data
            ancs = ancs[0].long()
            poss = poss[0].long()
            negs = negs[0].long()
            dataset_id = dataset_id[0].long()
            tem_bar = self.multi_handler.trn_handlers[dataset_id].ratio_500_all
            if tem_bar < 1.0 and np.random.uniform() > tem_bar:
                steps -= 1
                continue
            expert = self.model.summon(dataset_id)
            opt = self.model.summon_opt(dataset_id)
            feats = self.multi_handler.trn_handlers[dataset_id].projectors
            loss, loss_dict = expert.cal_loss((ancs, poss, negs), feats)
            opt.zero_grad()
            loss.backward()
            opt.step()
            sample_num = ancs.shape[0]
            tot_samp_num += sample_num
            ep_loss += loss.item() * sample_num
            ep_preloss += loss_dict['preloss'].item() * sample_num
            ep_regloss += loss_dict['regloss'].item()
            log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True)
            counter[dataset_id] += 1
            if (counter[dataset_id] + 1) % self.multi_handler.trn_handlers[dataset_id].reproj_steps == 0:
                self.multi_handler.trn_handlers[dataset_id].make_projectors()
            if (i + 1) % reassign_steps == 0:
                self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=False)
        ret = dict()
        ret['Loss'] = ep_loss / tot_samp_num
        ret['preLoss'] = ep_preloss / tot_samp_num
        ret['regLoss'] = ep_regloss / steps
        t.cuda.empty_cache()
        return ret
    def make_trn_masks(self, numpy_usrs, csr_mat):
        trn_masks = csr_mat[numpy_usrs].tocoo()
        cand_size = trn_masks.shape[1]
        trn_masks = t.from_numpy(np.stack([trn_masks.row, trn_masks.col], axis=0)).long()
        return trn_masks, cand_size
    def test_epoch(self, handler, dataset_id):
        with t.no_grad():
            tst_loader = handler.tst_loader
            self.model.eval()
            expert = self.model.summon(dataset_id)
            ep_recall, ep_ndcg = 0, 0
            ep_tstnum = len(tst_loader.dataset)
            steps = max(ep_tstnum // args.tst_batch, 1)
            for i, batch_data in enumerate(tst_loader):
                if args.tst_steps != -1 and i > args.tst_steps:
                    break
                usrs = batch_data.long()
                trn_masks, cand_size = self.make_trn_masks(batch_data.numpy(), tst_loader.dataset.csrmat)
                feats = handler.projectors
                all_preds = expert.pred_for_test((usrs, trn_masks), cand_size, feats, rerun_embed=(i == 0))
                _, top_locs = t.topk(all_preds, args.topk)
                top_locs = top_locs.cpu().numpy()
                recall, ndcg = self.calc_recall_ndcg(top_locs, tst_loader.dataset.tstLocs, usrs)
                ep_recall += recall
                ep_ndcg += ndcg
                log('Steps %d/%d: recall = %.2f, ndcg = %.2f ' % (i, steps, recall, ndcg), save=False, oneline=True)
            ret = dict()
            if args.tst_steps != -1:
                ep_tstnum = args.tst_steps * args.tst_batch
            ret['Recall'] = ep_recall / ep_tstnum
            ret['NDCG'] = ep_ndcg / ep_tstnum
            ret['tstNum'] = ep_tstnum
        t.cuda.empty_cache()
        return ret
    def calc_recall_ndcg(self, topLocs, tstLocs, batIds):
        assert topLocs.shape[0] == len(batIds)
        allRecall = allNdcg = 0
        for i in range(len(batIds)):
            temTopLocs = list(topLocs[i])
            temTstLocs = tstLocs[batIds[i]]
            tstNum = len(temTstLocs)
            maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))])
            recall = dcg = 0
            for val in temTstLocs:
                if val in temTopLocs:
                    recall += 1
                    dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2))
            recall = recall / tstNum
            ndcg = dcg / maxDcg
            allRecall += recall
            allNdcg += ndcg
        return allRecall, allNdcg
    def save_history(self):
        if args.epoch == 0:
            return
        with open('./History/' + args.save_path + '.his', 'wb') as fs:
            pickle.dump(self.metrics, fs)
        content = {
            'model': self.model,
        }
        t.save(content, './Models/' + args.save_path + '.mod')
        log('Model Saved: %s' % args.save_path)

    def load_model(self):
        ckp = t.load('./Models/' + args.load_model + '.mod')
        self.model = ckp['model']
        self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)
        with open('./History/' + args.load_model + '.his', 'rb') as fs:
            self.metrics = pickle.load(fs)
        log('Model Loaded')
if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if len(args.gpu.split(',')) == 2:
args.devices = ['cuda:0', 'cuda:1']
elif len(args.gpu.split(',')) > 2:
raise Exception('Devices should be less than 2')
else:
args.devices = ['cuda:0', 'cuda:0']
logger.saveDefault = True
setproctitle.setproctitle('AnyGraph')
log('Start')
datasets = dict()
datasets['all'] = [
'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech', 'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab', 'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1'
]
datasets['ecommerce'] = [
'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech'
]
datasets['academic'] = [
'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab'
]
datasets['others'] = [
'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1'
]
datasets['link1'] = [
'products_tech', 'yelp2018', 'yelp_textfeat', 'products_home', 'steam_textfeat', 'amazon_textfeat', 'amazon-book', 'citation-2019', 'citation-classic', 'pubmed', 'citeseer', 'ppa', 'p2p-Gnutella06', 'soc-Epinions1', 'email-Enron',
]
datasets['link2'] = [
'Photo', 'Goodreads', 'Fitness', 'ml1m', 'ml10m', 'gowalla', 'arxiv', 'arxiv-ta', 'cora', 'CS', 'collab', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'ddi', 'web-Stanford', 'roadNet-PA',
]
if args.dataset_setting in datasets.keys():
trn_datasets = tst_datasets = datasets[args.dataset_setting]
elif args.dataset_setting in datasets['all']:
trn_datasets = tst_datasets = [args.dataset_setting]
elif '+' in args.dataset_setting:
idx = args.dataset_setting.index('+')
trn_datasets = datasets[args.dataset_setting[:idx]]
tst_datasets = datasets[args.dataset_setting[idx+1:]]
elif '_in_' in args.dataset_setting:
idx = args.dataset_setting.index('_in_')
tst_datasets_1 = datasets[args.dataset_setting[:idx]]
tst_datasets_2 = datasets[args.dataset_setting[idx+len('_in_'):]]
tst_datasets = []
for data in tst_datasets_1:
if data in tst_datasets_2:
tst_datasets.append(data)
trn_datasets = tst_datasets
else:
raise Exception('Unrecognized dataset setting: ' + args.dataset_setting)
if '+' not in args.dataset_setting:
# No zero-shot prediction test
handler = MultiDataHandler(trn_datasets, [tst_datasets])
else:
handler = MultiDataHandler(trn_datasets, [trn_datasets, tst_datasets])
log('Load Data')
exp = Exp(handler)
exp.run()
print(args.load_model, args.dataset_setting)
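For reference, the `dataset_setting` argument accepts three forms: a named group (or a single dataset name from `datasets['all']`), `trn+tst` for cross-group zero-shot evaluation, and `a_in_b` for the intersection of two groups. A minimal standalone sketch of this parsing logic (the `groups` dict here is a toy stand-in for the real `datasets` mapping):

```python
def resolve_datasets(setting, groups):
    """Toy re-implementation of the dataset_setting parsing above.

    Returns (trn_datasets, tst_datasets). `groups` maps group names to
    dataset-name lists and must contain an 'all' key.
    """
    if setting in groups:
        return groups[setting], groups[setting]
    if setting in groups['all']:  # a single dataset name
        return [setting], [setting]
    if '+' in setting:  # train on the left group, zero-shot test on the right
        trn_name, tst_name = setting.split('+', 1)
        return groups[trn_name], groups[tst_name]
    if '_in_' in setting:  # intersection of two groups
        a, b = setting.split('_in_', 1)
        inter = [d for d in groups[a] if d in groups[b]]
        return inter, inter
    raise ValueError('Unrecognized dataset setting: ' + setting)

groups = {
    'all': ['cora', 'pubmed', 'yelp2018'],
    'academic': ['cora', 'pubmed'],
    'ecommerce': ['yelp2018'],
}
```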
================================================
FILE: model.py
================================================
import torch as t
from torch import nn
import torch.nn.functional as F
from params import args
import numpy as np
from Utils.TimeLogger import log
from torch.nn import MultiheadAttention
from time import time
init = nn.init.xavier_uniform_
uniformInit = nn.init.uniform_
class FeedForwardLayer(nn.Module):
def __init__(self, in_feat, out_feat, bias=True, act=None):
super(FeedForwardLayer, self).__init__()
self.linear = nn.Linear(in_feat, out_feat, bias=bias)#, dtype=t.bfloat16)
if act == 'identity' or act is None:
self.act = None
elif act == 'leaky':
self.act = nn.LeakyReLU(negative_slope=args.leaky)
elif act == 'relu':
self.act = nn.ReLU()
elif act == 'relu6':
self.act = nn.ReLU6()
else:
raise Exception('Unsupported activation: ' + str(act))
def forward(self, embeds):
if self.act is None:
return self.linear(embeds)
return self.act(self.linear(embeds))
class TopoEncoder(nn.Module):
def __init__(self):
super(TopoEncoder, self).__init__()
self.layer_norm = nn.LayerNorm(args.latdim, elementwise_affine=False)
def forward(self, adj, embeds, normed=False):
with t.no_grad():
if not normed:
embeds = self.layer_norm(embeds)
# embeds_list = []
final_embeds = 0
if args.gnn_layer == 0:
final_embeds = embeds
# embeds_list.append(embeds)
for _ in range(args.gnn_layer):
embeds = t.spmm(adj, embeds)
final_embeds += embeds
# embeds_list.append(embeds)
embeds = final_embeds#sum(embeds_list)
return embeds
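TopoEncoder is parameter-free: it layer-normalizes the input once, then accumulates the sum of all propagation orders `A@X + A^2@X + ... + A^L@X` over `gnn_layer` hops (or returns `X` unchanged when `gnn_layer == 0`). A dense NumPy sketch of that accumulation, for illustration only; the real code runs `t.spmm` on a normalized sparse adjacency:

```python
import numpy as np

def topo_encode(adj, embeds, gnn_layer):
    # Accumulate A@X + A^2@X + ... + A^L@X, as in TopoEncoder.forward.
    if gnn_layer == 0:
        return embeds.copy()
    final = np.zeros_like(embeds)
    for _ in range(gnn_layer):
        embeds = adj @ embeds
        final += embeds
    return final

# With an identity adjacency and L=2 hops, the result is exactly 2 * X.
X = np.arange(6.0).reshape(3, 2)
out = topo_encode(np.eye(3), X, gnn_layer=2)
```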
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(args.fc_layer)])
self.layer_norms = nn.Sequential(*[nn.LayerNorm(args.latdim, elementwise_affine=True) for _ in range(args.fc_layer)])
self.dropout = nn.Dropout(p=args.drop_rate)
def forward(self, embeds):
for i in range(args.fc_layer):
embeds = self.layer_norms[i](self.dropout(self.dense_layers[i](embeds)) + embeds)
return embeds
class GTLayer(nn.Module):
def __init__(self):
super(GTLayer, self).__init__()
self.multi_head_attention = MultiheadAttention(args.latdim, args.head, dropout=0.1, bias=False)#, dtype=t.bfloat16)
self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(2)])# bias=False
self.layer_norm1 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
self.layer_norm2 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
self.fc_dropout = nn.Dropout(p=args.drop_rate)
def _pick_anchors(self, embeds):
perm = t.randperm(embeds.shape[0])
anchors = perm[:args.anchor]
return embeds[anchors]
def forward(self, embeds):
anchor_embeds = self._pick_anchors(embeds)
_anchor_embeds, _ = self.multi_head_attention(anchor_embeds, embeds, embeds)
anchor_embeds = _anchor_embeds + anchor_embeds
_embeds, _ = self.multi_head_attention(embeds, anchor_embeds, anchor_embeds, need_weights=False)
embeds = self.layer_norm1(_embeds + embeds)
_embeds = self.fc_dropout(self.dense_layers(embeds))
embeds = (self.layer_norm2(_embeds + embeds))
return embeds
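GTLayer sidesteps full O(N^2) self-attention by routing through `args.anchor` randomly sampled anchor nodes: the anchors first attend over all nodes, then every node attends back to the anchors, so the cost is O(N*k) for k anchors. A single-head NumPy sketch of this two-step pattern (no learned projections, layer norms, or dropout):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def anchor_attention(embeds, num_anchors, rng):
    n, d = embeds.shape
    anchors = embeds[rng.permutation(n)[:num_anchors]]       # pick k random anchors
    # Step 1: anchors attend to all nodes (k x n attention map), plus residual.
    anchors = softmax(anchors @ embeds.T / np.sqrt(d)) @ embeds + anchors
    # Step 2: all nodes attend to the anchors (n x k attention map), plus residual.
    return softmax(embeds @ anchors.T / np.sqrt(d)) @ anchors + embeds

rng = np.random.default_rng(0)
out = anchor_attention(rng.normal(size=(100, 8)), num_anchors=4, rng=rng)
```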
class GraphTransformer(nn.Module):
def __init__(self):
super(GraphTransformer, self).__init__()
self.gt_layers = nn.Sequential(*[GTLayer() for i in range(args.gt_layer)])
def forward(self, embeds):
for i, layer in enumerate(self.gt_layers):
embeds = layer(embeds) / args.scale_layer
return embeds
class Feat_Projector(nn.Module):
def __init__(self, feats):
super(Feat_Projector, self).__init__()
if args.proj_method == 'uniform':
self.proj_embeds = self.uniform_proj(feats)
elif args.proj_method == 'svd' or args.proj_method == 'both':
self.proj_embeds = self.svd_proj(feats)
elif args.proj_method == 'random':
self.proj_embeds = self.random_proj(feats)
elif args.proj_method == 'original':
self.proj_embeds = feats
self.proj_embeds = t.flip(self.proj_embeds, dims=[-1])
self.proj_embeds = self.proj_embeds.detach()
def svd_proj(self, feats):
if args.latdim > feats.shape[0] or args.latdim > feats.shape[1]:
dim = min(feats.shape[0], feats.shape[1])
decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=dim, niter=args.niter)
decom_feats = t.concat([decom_feats, t.zeros([decom_feats.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
s = t.concat([s, t.zeros(args.latdim - dim).to(args.devices[0])])
else:
decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=args.latdim, niter=args.niter)
decom_feats = decom_feats @ t.diag(t.sqrt(s))
return decom_feats.cpu()
def uniform_proj(self, feats):
projection = init(t.empty(args.featdim, args.latdim))
return feats @ projection
def random_proj(self, feats):
projection = init(t.empty(feats.shape[0], args.latdim))
return projection
def forward(self):
return self.proj_embeds
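The SVD projection maps features of any input dimensionality into a shared `latdim`-sized space: it takes a rank-`latdim` decomposition, scales the left singular vectors by `sqrt(s)`, and zero-pads when the matrix is smaller than `latdim` in either dimension. A NumPy sketch of the same idea, using exact `np.linalg.svd` in place of the randomized `t.svd_lowrank`:

```python
import numpy as np

def svd_project(feats, latdim):
    q = min(latdim, *feats.shape)
    u, s, _ = np.linalg.svd(feats, full_matrices=False)
    proj = u[:, :q] * np.sqrt(s[:q])          # scale by sqrt of singular values
    if q < latdim:                            # pad up to latdim with zeros
        proj = np.concatenate([proj, np.zeros((proj.shape[0], latdim - q))], axis=1)
    return proj

emb = svd_project(np.random.rand(120, 300), latdim=64)   # full rank-64 projection
small = svd_project(np.random.rand(10, 5), latdim=64)    # rank 5, zero-padded to 64
```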
class Adj_Projector(nn.Module):
def __init__(self, adj):
super(Adj_Projector, self).__init__()
if args.proj_method == 'adj_svd' or args.proj_method == 'both':
self.proj_embeds = self.svd_proj(adj)
self.proj_embeds = self.proj_embeds.detach()
def svd_proj(self, adj):
q = args.latdim
if args.latdim > adj.shape[0] or args.latdim > adj.shape[1]:
dim = min(adj.shape[0], adj.shape[1])
svd_u, s, svd_v = t.svd_lowrank(adj, q=dim, niter=args.niter)
svd_u = t.concat([svd_u, t.zeros([svd_u.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
svd_v = t.concat([svd_v, t.zeros([svd_v.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
s = t.concat([s, t.zeros(args.latdim-dim).to(args.devices[0])])
else:
svd_u, s, svd_v = t.svd_lowrank(adj, q=q, niter=args.niter)
svd_u = svd_u @ t.diag(t.sqrt(s))
svd_v = svd_v @ t.diag(t.sqrt(s))
if adj.shape[0] != adj.shape[1]:
projection = t.concat([svd_u, svd_v], dim=0)
else:
projection = svd_u + svd_v
return projection.cpu()
def forward(self):
return self.proj_embeds
class Expert(nn.Module):
def __init__(self):
super(Expert, self).__init__()
self.topo_encoder = TopoEncoder().to(args.devices[0])
if args.nn == 'mlp':
self.trainable_nn = MLP().to(args.devices[1])
else:
self.trainable_nn = GraphTransformer().to(args.devices[1])
self.trn_count = 1
def forward(self, projectors, pck_nodes=None):
embeds = projectors.to(args.devices[1])
if pck_nodes is not None:
embeds = embeds[pck_nodes]
embeds = self.trainable_nn(embeds)
return embeds
def pred_norm(self, pos_preds, neg_preds):
pos_preds_num = pos_preds.shape[0]
neg_preds_shape = neg_preds.shape
preds = t.concat([pos_preds, neg_preds.view(-1)])
preds = preds - preds.max()
pos_preds = preds[:pos_preds_num]
neg_preds = preds[pos_preds_num:].view(neg_preds_shape)
return pos_preds, neg_preds
def cal_loss(self, batch_data, projectors):
ancs, poss, negs = list(map(lambda x: x.to(args.devices[1]), batch_data))
self.trn_count += ancs.shape[0]
pck_nodes = t.concat([ancs, poss, negs])
final_embeds = self.forward(projectors, pck_nodes)
# anc_embeds, pos_embeds, neg_embeds = final_embeds[ancs], final_embeds[poss], final_embeds[negs]
anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds, [ancs.shape[0]] * 3)
if final_embeds.isinf().any() or final_embeds.isnan().any():
raise Exception('Final embeddings contain NaN or Inf')
if args.loss == 'ce':
pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T)
if pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any():
raise Exception('Predictions contain NaN or Inf')
pos_loss = pos_preds
neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log()
pre_loss = -(pos_loss - neg_loss).mean()
elif args.loss == 'bpr':
pos_preds = (anc_embeds * pos_embeds).sum(-1)
neg_preds = (anc_embeds * neg_embeds).sum(-1)
pos_loss, neg_loss = pos_preds, neg_preds
pre_loss = -((pos_preds - neg_preds).sigmoid() + 1e-10).log().mean()
if t.isinf(pre_loss).any() or t.isnan(pre_loss).any():
raise Exception('NaN or Inf')
reg_loss = sum(list(map(lambda W: W.norm(2).square() * args.reg, self.parameters())))
loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean()}
return pre_loss + reg_loss, loss_dict
def pred_for_test(self, batch_data, cand_size, projectors, rerun_embed=True):
ancs, trn_mask = list(map(lambda x: x.to(args.devices[1]), batch_data))
if rerun_embed:
try:
final_embeds = self.forward(projectors)
except Exception:
final_embeds_list = []
div = args.batch * 3
temlen = projectors.shape[0] // div
for i in range(temlen):
st, ed = div * i, div * (i + 1)
tem_projectors = projectors[st: ed, :]
final_embeds_list.append(self.forward(tem_projectors))
if temlen * div < projectors.shape[0]:
tem_projectors = projectors[temlen*div:, :]
final_embeds_list.append(self.forward(tem_projectors))
final_embeds = t.concat(final_embeds_list, dim=0)
self.final_embeds = final_embeds
final_embeds = self.final_embeds
anc_embeds = final_embeds[ancs]
cand_embeds = final_embeds[-cand_size:]
mask_mat = t.sparse.FloatTensor(trn_mask, t.ones(trn_mask.shape[1]).to(args.devices[1]), t.Size([ancs.shape[0], cand_size]))
dense_mat = mask_mat.to_dense()
all_preds = anc_embeds @ cand_embeds.T * (1 - dense_mat) - dense_mat * 1e8
return all_preds
def attempt(self, topo_embeds, dataset):
final_embeds = self.trainable_nn(topo_embeds)
rows, cols, negs = list(map(lambda x: t.from_numpy(x).long().to(args.devices[1]), [dataset.ancs, dataset.poss, dataset.negs]))
if rows.shape[0] > args.attempt_cache:
random_perm = t.randperm(rows.shape[0], device=args.devices[0])
pck_perm = random_perm[:args.attempt_cache]
rows = rows[pck_perm]
cols = cols[pck_perm]
negs = negs[pck_perm]
while True:
try:
row_embeds = final_embeds[rows]
col_embeds = final_embeds[cols]
neg_embeds = final_embeds[negs]
score = ((row_embeds * col_embeds).sum(-1) - (row_embeds * neg_embeds).sum(-1)).sigmoid().mean().item()
break
except Exception:
args.attempt_cache = args.attempt_cache // 2
random_perm = t.randperm(rows.shape[0], device=args.devices[0])
pck_perm = random_perm[:args.attempt_cache]
rows = rows[pck_perm]
cols = cols[pck_perm]
negs = negs[pck_perm]
t.cuda.empty_cache()
return score
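In `pred_norm`, subtracting the global maximum score from both positive and negative predictions before exponentiation leaves the cross-entropy loss mathematically unchanged (softmax-style losses are shift-invariant) while preventing overflow in `.exp()`. A NumPy check of that invariance on small scores:

```python
import numpy as np

def ce_loss(pos, neg, shift=False):
    # -mean( pos - log(exp(pos) + sum_j exp(neg_j)) ), as in the 'ce' branch above.
    if shift:
        m = max(pos.max(), neg.max())   # the max-subtraction done by pred_norm
        pos, neg = pos - m, neg - m
    return -(pos - np.log(np.exp(neg).sum(-1) + np.exp(pos))).mean()

pos = np.array([2.0, 1.5])
neg = np.array([[0.1, -0.3], [0.7, 0.2]])
```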
class AnyGraph(nn.Module):
def __init__(self):
super(AnyGraph, self).__init__()
self.experts = nn.ModuleList([Expert() for _ in range(args.expert_num)]).cuda()
self.opts = list(map(lambda expert: t.optim.Adam(expert.parameters(), lr=args.lr, weight_decay=0), self.experts))
def assign_experts(self, handlers, reca=True, log_assignment=False):
if args.expert_num == 1:
self.assignment = [0] * len(handlers)
return
try:
expert_scores = np.array(list(map(lambda expert: expert.trn_count, self.experts)))
expert_scores = (1.0 - expert_scores / np.sum(expert_scores)) * args.reca_range + 1.0 - args.reca_range / 2
except Exception:
expert_scores = np.ones(len(self.experts))
with t.no_grad():
assignment = [list() for i in range(len(handlers))]
for dataset_id, handler in enumerate(handlers):
topo_embeds = handler.projectors.to(args.devices[1])
for expert_id, expert in enumerate(self.experts):
expert = expert.to(args.devices[1])
score = expert.attempt(topo_embeds, handler.trn_loader.dataset)
if reca:
score *= expert_scores[expert_id]
assignment[dataset_id].append((expert_id, score))
assignment[dataset_id].sort(key=lambda x: x[1], reverse=True)
if log_assignment:
print('\n----------\nAssignment')
for dataset_id, handler in enumerate(handlers):
out = ''
for exp_idx in range(min(4, len(self.experts))):
out += f'({assignment[dataset_id][exp_idx][0]}, {assignment[dataset_id][exp_idx][1]}) '
print(handler.data_name, out)
print('----------\n')
self.assignment = list(map(lambda x: x[0][0], assignment))
def summon(self, dataset_id):
return self.experts[self.assignment[dataset_id]]
def summon_opt(self, dataset_id):
return self.opts[self.assignment[dataset_id]]
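In `assign_experts`, each expert's routing score is rescaled by `(1 - trn_count/total) * reca_range + 1 - reca_range/2`, so experts that have already absorbed more training batches receive multipliers below 1 and under-used experts above 1, keeping the mixture balanced. A small NumPy sketch of this reweighting (`reca_range=0.1` here is an illustrative value, not necessarily the repo default):

```python
import numpy as np

def reca_weights(trn_counts, reca_range=0.1):
    # Multiplier per expert: below 1 for heavily trained experts, above 1 for idle ones.
    counts = np.asarray(trn_counts, dtype=float)
    return (1.0 - counts / counts.sum()) * reca_range + 1.0 - reca_range / 2

w_equal = reca_weights([100, 100, 100, 100])   # equal usage -> identical weights
w_skew = reca_weights([300, 100])              # busier expert 0 gets the smaller weight
```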
================================================
FILE: node_classification/Utils/TimeLogger.py
================================================
import datetime
logmsg = ''
timemark = dict()
saveDefault = False
def log(msg, save=None, oneline=False):
global logmsg
global saveDefault
time = datetime.datetime.now()
tem = '%s: %s' % (time, msg)
if save is not None:
if save:
logmsg += tem + '\n'
elif saveDefault:
logmsg += tem + '\n'
if oneline:
print(tem, end='\r')
else:
print(tem)
def marktime(marker):
global timemark
timemark[marker] = datetime.datetime.now()
if __name__ == '__main__':
log('')
================================================
FILE: node_classification/data_handler.py
================================================
import pickle
import numpy as np
from scipy.sparse import csr_matrix, coo_matrix, dok_matrix
from params import args
import scipy.sparse as sp
from Utils.TimeLogger import log
import torch as t
import torch.utils.data as data
import torch_geometric.transforms as T
from model import Feat_Projector, Adj_Projector, TopoEncoder
import os
class MultiDataHandler:
def __init__(self, trn_datasets, tst_datasets_group):
all_datasets = trn_datasets
all_tst_datasets = []
for tst_datasets in tst_datasets_group:
all_datasets = all_datasets + tst_datasets
all_tst_datasets += tst_datasets
all_datasets = list(set(all_datasets))
all_datasets.sort()
self.trn_handlers = []
self.tst_handlers_group = [list() for i in range(len(tst_datasets_group))]
for data_name in all_datasets:
trn_flag = data_name in trn_datasets
tst_flag = data_name in all_tst_datasets
handler = DataHandler(data_name)
if trn_flag:
self.trn_handlers.append(handler)
if tst_flag:
for i in range(len(tst_datasets_group)):
if data_name in tst_datasets_group[i]:
self.tst_handlers_group[i].append(handler)
self.make_joint_trn_loader()
def make_joint_trn_loader(self):
loader_datasets = []
for trn_handler in self.trn_handlers:
tem_dataset = trn_handler.trn_loader.dataset
loader_datasets.append(tem_dataset)
joint_dataset = JointTrnData(loader_datasets)
self.joint_trn_loader = data.DataLoader(joint_dataset, batch_size=1, shuffle=True, num_workers=0)
def remake_initial_projections(self):
for i in range(len(self.trn_handlers)):
trn_handler = self.trn_handlers[i]
trn_handler.make_projectors()
class DataHandler:
def __init__(self, data_name):
self.data_name = data_name
self.get_data_files()
log(f'Loading dataset {data_name}')
self.topo_encoder = TopoEncoder()
self.load_data()
def get_data_files(self):
predir = f'/home/user_name/data/zero-shot datasets/node_data/{self.data_name}/'
# predir = f'../handle_node_data/{self.data_name}/'
if os.path.exists(predir + 'feats.pkl'):
self.feat_file = predir + 'feats.pkl'
else:
self.feat_file = None
self.trnfile = predir + 'trn_mat.pkl'
self.tstfile = predir + 'tst_mat.pkl'
self.fewshotfile = predir + 'fewshot_mat_{shot}.pkl'.format(shot=args.shot)
self.valfile = predir + 'val_mat.pkl'
def load_one_file(self, filename):
with open(filename, 'rb') as fs:
ret = (pickle.load(fs) != 0).astype(np.float32)
if not isinstance(ret, coo_matrix):
ret = sp.coo_matrix(ret)
return ret
def load_feats(self, filename):
try:
with open(filename, 'rb') as fs:
feats = pickle.load(fs)
except Exception as e:
print(filename + str(e))
exit()
return feats
def normalize_adj(self, mat, log=False):
degree = np.array(mat.sum(axis=-1))
d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
if mat.shape[0] == mat.shape[1]:
return mat.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo()
else:
tem = d_inv_sqrt_mat.dot(mat)
col_degree = np.array(mat.sum(axis=0))
d_inv_sqrt = np.reshape(np.power(col_degree, -0.5), [-1])
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
return tem.dot(d_inv_sqrt_mat).tocoo()
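For square adjacencies this computes the symmetric normalization `D^{-1/2} A D^{-1/2}`, which bounds the spectrum of the propagation operator and keeps the stacked hops in TopoEncoder numerically stable; non-square (bipartite) matrices get the analogous row/column scaling. A scipy check on a triangle graph, where every node has degree 2 and every normalized entry becomes 1/2:

```python
import numpy as np
import scipy.sparse as sp

def sym_normalize(mat):
    # D^{-1/2} A D^{-1/2} for a square sparse adjacency.
    deg = np.asarray(mat.sum(axis=1)).reshape(-1)
    d_inv_sqrt = np.zeros_like(deg)
    nz = deg > 0                       # guard isolated nodes against divide-by-zero
    d_inv_sqrt[nz] = deg[nz] ** -0.5
    d = sp.diags(d_inv_sqrt)
    return d.dot(mat).dot(d).tocoo()

# Triangle graph: every node has degree 2.
adj = sp.coo_matrix(np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float))
normed = sym_normalize(adj)
```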
def unique_numpy(self, row, col):
hash_vals = row * args.node_num + col
hash_vals = np.unique(hash_vals).astype(np.int64)
col = hash_vals % args.node_num
row = (hash_vals - col).astype(np.int64) // args.node_num
return row, col
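`unique_numpy` deduplicates (row, col) edge pairs by hashing each pair into a single integer `row * node_num + col`, uniquifying, and decoding back. A standalone NumPy sketch with `node_num` passed explicitly instead of read from `args`:

```python
import numpy as np

def unique_edges(row, col, node_num):
    # Encode each (row, col) pair as one integer, dedupe, then decode.
    hash_vals = np.unique(row.astype(np.int64) * node_num + col)
    return hash_vals // node_num, hash_vals % node_num

row = np.array([0, 1, 0, 2])
col = np.array([1, 2, 1, 0])   # the edge (0, 1) appears twice
r, c = unique_edges(row, col, node_num=3)
```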
def make_torch_adj(self, mat, unidirectional_for_asym=False):
if mat.shape[0] == mat.shape[1]:
_row = mat.row
_col = mat.col
row = np.concatenate([_row, _col]).astype(np.int64)
col = np.concatenate([_col, _row]).astype(np.int64)
row, col = self.unique_numpy(row, col)
data = np.ones_like(row)
mat = coo_matrix((data, (row, col)), mat.shape)
if args.selfloop == 1:
mat = (mat + sp.eye(mat.shape[0])) * 1.0
normed_asym_mat = self.normalize_adj(mat)
row = t.from_numpy(normed_asym_mat.row).long()
col = t.from_numpy(normed_asym_mat.col).long()
idxs = t.stack([row, col], dim=0)
vals = t.from_numpy(normed_asym_mat.data).float()
shape = t.Size(normed_asym_mat.shape)
asym_adj = t.sparse.FloatTensor(idxs, vals, shape)
return asym_adj
elif unidirectional_for_asym:
mat = (mat != 0) * 1.0
mat = self.normalize_adj(mat, log=True)
idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
vals = t.from_numpy(mat.data.astype(np.float32))
shape = t.Size(mat.shape)
return t.sparse.FloatTensor(idxs, vals, shape)
else:
# make ui adj
a = sp.csr_matrix((args.user_num, args.user_num))
b = sp.csr_matrix((args.item_num, args.item_num))
mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
mat = (mat != 0) * 1.0
if args.selfloop == 1:
mat = (mat + sp.eye(mat.shape[0])) * 1.0
mat = self.normalize_adj(mat)
# make cuda tensor
idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
vals = t.from_numpy(mat.data.astype(np.float32))
shape = t.Size(mat.shape)
return t.sparse.FloatTensor(idxs, vals, shape)
def load_data(self):
tst_mat = self.load_one_file(self.tstfile)
trn_mat = self.load_one_file(self.trnfile)
# fewshot_mat = self.load_one_file(self.fewshotfile)
if self.feat_file is not None:
self.feats = t.from_numpy(self.load_feats(self.feat_file)).float()
args.featdim = self.feats.shape[1]
else:
self.feats = None
args.featdim = args.latdim
if trn_mat.shape[0] != trn_mat.shape[1]:
args.user_num, args.item_num = trn_mat.shape
args.node_num = args.user_num + args.item_num
print('Dataset: {data_name}, User num: {user_num}, Item num: {item_num}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, user_num=args.user_num, item_num=args.item_num, node_num=args.node_num, edge_num=trn_mat.nnz))
else:
args.node_num = trn_mat.shape[0]
print('Dataset: {data_name}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, node_num=args.node_num, edge_num=trn_mat.nnz+tst_mat.nnz))
if args.tst_mode == 'tst':
tst_data = NodeTstData(tst_mat)
self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
# tst_loss_data = TrnData(tst_mat)
# self.tst_loss_loader = data.DataLoader(tst_loss_data, batch_size=args.batch, shuffle=False, num_workers=0)
self.tst_input_adj = self.make_torch_adj(trn_mat)
elif args.tst_mode == 'val':
raise Exception('Validation is not implemented for node classification')
else:
raise Exception('Specify proper test mode')
self.trn_mat = trn_mat
trn_data = TrnData(self.trn_mat)
self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0)
if args.tst_mode == 'tst':
self.trn_input_adj = self.tst_input_adj
else:
self.trn_input_adj = self.make_torch_adj(trn_mat)
if self.trn_mat.shape[0] == self.trn_mat.shape[1]:
self.asym_adj = self.trn_input_adj
else:
self.asym_adj = self.make_torch_adj(self.trn_mat, unidirectional_for_asym=True)
self.make_projectors()
self.reproj_steps = max(len(self.trn_loader.dataset) // (10 * args.batch), args.proj_trn_steps)
self.ratio_500_all = 250 / len(self.trn_loader)
def make_projectors(self):
with t.no_grad():
projectors = []
if args.proj_method == 'adj_svd' or args.proj_method == 'both':
tem = self.asym_adj.to(args.devices[0])
projectors = [Adj_Projector(tem)]
if self.feats is not None and args.proj_method != 'adj_svd':
tem = self.feats.to(args.devices[0])
projectors.append(Feat_Projector(tem))
assert (args.tst_mode == 'tst' and args.trn_mode == 'train-all') or (args.tst_mode == 'val' and args.trn_mode == 'fewshot')
feats = projectors[0]()
if len(projectors) == 2:
feats2 = projectors[1]()
feats = feats + feats2
try:
self.projectors = self.topo_encoder(self.trn_input_adj.to(args.devices[0]), feats.to(args.devices[0])).detach().cpu()
except Exception:
print(f'{self.data_name} memory overflow')
mean, std = feats.mean(dim=-1, keepdim=True), feats.std(dim=-1, keepdim=True)
tem_adj = self.trn_input_adj.to(args.devices[0])
mem_cache = 256
projectors_list = []
for i in range(feats.shape[1] // mem_cache):
st, ed = i * mem_cache, (i + 1) * mem_cache
tem_feats = (feats[:, st:ed] - mean) / (std + 1e-8)
tem_feats = self.topo_encoder(tem_adj, tem_feats.to(args.devices[0]), normed=True).detach().cpu()
projectors_list.append(tem_feats)
self.projectors = t.concat(projectors_list, dim=-1)
t.cuda.empty_cache()
class TrnData(data.Dataset):
def __init__(self, coomat):
self.ancs = coomat.row
# self.dokmat = set(list(map(lambda idx: (self.ancs[idx], self.poss[idx]), range(len(self.ancs)))))
# self.dokmat = coomat.todok()
self.cand_num = coomat.shape[1]
self.neg_shift = 0 if coomat.shape[0] == coomat.shape[1] else coomat.shape[0]
self.poss = coomat.col + self.neg_shift
self.neg_sampling()
def neg_sampling(self):
self.negs = np.random.randint(self.cand_num + self.neg_shift, size=self.poss.shape[0])
# self.negs = np.zeros_like(self.ancs)
# for i in range(len(self.ancs)):
# u = self.ancs[i]
# while True:
# i_neg = np.random.randint(self.cand_num)
# if (u, i_neg) not in self.dokmat:
# break
# self.negs[i] = i_neg
# self.negs += self.neg_shift
def __len__(self):
return len(self.ancs)
def __getitem__(self, idx):
return self.ancs[idx], self.poss[idx], self.negs[idx]
class NodeTstData(data.Dataset):
def __init__(self, tst_mat):
self.class_num = tst_mat.shape[1]
self.nodes, self.labels = tst_mat.row, tst_mat.col
def __len__(self):
return len(self.nodes)
def __getitem__(self, idx):
return self.nodes[idx], self.labels[idx]
class JointTrnData(data.Dataset):
def __init__(self, dataset_list):
self.batch_dataset_ids = []
self.batch_st_ed_list = []
self.dataset_list = dataset_list
for dataset_id, dataset in enumerate(dataset_list):
samp_num = len(dataset) // args.batch + (1 if len(dataset) % args.batch != 0 else 0)
for j in range(samp_num):
self.batch_dataset_ids.append(dataset_id)
st = j * args.batch
ed = min((j + 1) * args.batch, len(dataset))
self.batch_st_ed_list.append((st, ed))
def neg_sampling(self):
for dataset in self.dataset_list:
dataset.neg_sampling()
def __len__(self):
return len(self.batch_dataset_ids)
def __getitem__(self, idx):
st, ed = self.batch_st_ed_list[idx]
dataset_id = self.batch_dataset_ids[idx]
return *self.dataset_list[dataset_id][st: ed], dataset_id
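`JointTrnData` flattens several datasets into one loader by precomputing, per batch, the source dataset id and a (start, end) slice; the outer DataLoader with `batch_size=1` then shuffles whole batches across datasets. A pure-Python sketch of the index construction, with the batch size passed explicitly:

```python
def joint_batches(dataset_lens, batch):
    # One (dataset_id, start, end) triple per batch, ceil-dividing each dataset.
    out = []
    for ds_id, n in enumerate(dataset_lens):
        for j in range((n + batch - 1) // batch):
            out.append((ds_id, j * batch, min((j + 1) * batch, n)))
    return out

batches = joint_batches([10, 5], batch=4)   # 3 batches from the first dataset, 2 from the second
```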
================================================
FILE: node_classification/main.py
================================================
import torch as t
from torch import nn
import Utils.TimeLogger as logger
from Utils.TimeLogger import log
from params import args
from model import Expert, Feat_Projector, Adj_Projector, AnyGraph
from data_handler import MultiDataHandler, DataHandler
import numpy as np
import pickle
import os
import setproctitle
import time
from sklearn.metrics import f1_score
class Exp:
def __init__(self, multi_handler):
self.multi_handler = multi_handler
print(list(map(lambda x: x.data_name, multi_handler.trn_handlers)))
for group_id, tst_handlers in enumerate(multi_handler.tst_handlers_group):
print(f'Test group {group_id}', list(map(lambda x: x.data_name, tst_handlers)))
self.metrics = dict()
trn_mets = ['Loss', 'preLoss']
tst_mets = ['Acc', 'F1', 'Loss', 'preLoss']
mets = trn_mets + tst_mets
for met in mets:
if met in trn_mets:
self.metrics['Train' + met] = list()
if met in tst_mets:
for i in range(len(self.multi_handler.tst_handlers_group)):
self.metrics['Test' + str(i) + met] = list()
def make_print(self, name, ep, reses, save, data_name=None):
if data_name is None:
ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)
else:
ret = 'Epoch %d/%d, %s %s: ' % (ep, args.epoch, data_name, name)
for metric in reses:
val = reses[metric]
ret += '%s = %.4f, ' % (metric, val)
tem = name + metric if data_name is None else name + data_name + metric
if save and tem in self.metrics:
self.metrics[tem].append(val)
ret = ret[:-2] + ' '
return ret
def run(self):
self.prepare_model()
log('Model Prepared')
stloc = 0
if args.load_model is not None:
self.load_model()
stloc = len(self.metrics['TrainLoss']) * args.tst_epoch - (args.tst_epoch - 1)
best_f1, best_ep = 0, -1
early_stop = False
for ep in range(stloc, args.epoch):
tst_flag = (ep % args.tst_epoch == 0)
start_time = time.time()
self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=True)
reses = self.train_epoch()
log(self.make_print('Train', ep, reses, tst_flag))
self.multi_handler.remake_initial_projections()
end_time = time.time()
print(f'Epoch time: {end_time - start_time:.2f}s')
if tst_flag:
for handler_group_id in range(len(self.multi_handler.tst_handlers_group)):
tst_handlers = self.multi_handler.tst_handlers_group[handler_group_id]
self.model.assign_experts(tst_handlers, reca=False, log_assignment=True)
acc, f1, tstnum = 0, 0, 0
for i, handler in enumerate(tst_handlers):
reses = self.test_epoch(handler, i)
log(self.make_print(handler.data_name, ep, reses, False))
acc += reses['Acc'] * reses['tstNum']
f1 += reses['F1'] * reses['tstNum']
tstnum += reses['tstNum']
reses = {'Acc': acc / tstnum, 'F1': f1 / tstnum}
log(self.make_print('Test'+str(handler_group_id), ep, reses, tst_flag))
if reses['F1'] > best_f1:
best_f1 = reses['F1']
best_ep = ep
self.save_history()
print()
for test_group_id in range(len(self.multi_handler.tst_handlers_group)):
repeat_times = 10
overall_recall, overall_ndcg = np.zeros(repeat_times), np.zeros(repeat_times)
overall_tstnum = 0
tst_handlers = self.multi_handler.tst_handlers_group[test_group_id]
if args.assignment == 'one-graph-one-expert':
self.model.assign_experts(tst_handlers, reca=False, log_assignment=True)
for i, handler in enumerate(tst_handlers):
mets = dict()
for _ in range(repeat_times):
handler.make_projectors()
if not args.assignment == 'one-graph-one-expert':
self.model.assign_experts([handler], reca=False, log_assignment=False)
reses = self.test_epoch(handler, i if args.assignment == 'one-graph-one-expert' else 0)
log(self.make_print('Test', args.epoch, reses, False))
for met in reses:
if met not in mets:
mets[met] = []
mets[met].append(reses[met])
tstnum = reses['tstNum']
tot_reses = dict()
for met in reses:
tem_arr = np.array(mets[met])
tot_reses[met + '_std'] = tem_arr.std()
tot_reses[met + '_mean'] = tem_arr.mean()
overall_recall += np.array(mets['Acc']) * tstnum
overall_ndcg += np.array(mets['F1']) * tstnum
overall_tstnum += tstnum
log(self.make_print('Test', args.epoch, tot_reses, False, handler.data_name))
overall_recall /= overall_tstnum
overall_ndcg /= overall_tstnum
overall_res = dict()
overall_res['Acc_mean'] = overall_recall.mean()
overall_res['Acc_std'] = overall_recall.std()
overall_res['F1_mean'] = overall_ndcg.mean()
overall_res['F1_std'] = overall_ndcg.std()
log(self.make_print('Overall Test', args.epoch, overall_res, False))
self.save_history()
def print_model_size(self):
total_params = 0
trainable_params = 0
non_trainable_params = 0
for param in self.model.parameters():
tem = np.prod(param.size())
total_params += tem
if param.requires_grad:
trainable_params += tem
else:
non_trainable_params += tem
print(f'Total params: {total_params/1e6}')
print(f'Trainable params: {trainable_params/1e6}')
print(f'Non-trainable params: {non_trainable_params/1e6}')
def prepare_model(self):
self.model = AnyGraph()
t.cuda.empty_cache()
self.print_model_size()
def train_epoch(self):
self.model.train()
trn_loader = self.multi_handler.joint_trn_loader
trn_loader.dataset.neg_sampling()
ep_loss, ep_preloss, ep_regloss = 0, 0, 0
steps = len(trn_loader)
tot_samp_num = 0
counter = [0] * len(self.multi_handler.trn_handlers)
reassign_steps = sum(list(map(lambda x: x.reproj_steps, self.multi_handler.trn_handlers)))
for i, batch_data in enumerate(trn_loader):
if args.epoch_max_step > 0 and i >= args.epoch_max_step:
break
ancs, poss, negs, dataset_id = batch_data
ancs = ancs[0].long()
poss = poss[0].long()
negs = negs[0].long()
dataset_id = dataset_id[0].long()
tem_bar = self.multi_handler.trn_handlers[dataset_id].ratio_500_all
if tem_bar < 1.0 and np.random.uniform() > tem_bar:
steps -= 1
continue
expert = self.model.summon(dataset_id)#.cuda()
opt = self.model.summon_opt(dataset_id)
# adj = self.multi_handler.trn_handlers[dataset_id].trn_input_adj
feats = self.multi_handler.trn_handlers[dataset_id].projectors
loss, loss_dict = expert.cal_loss((ancs, poss, negs), feats)
opt.zero_grad()
loss.backward()
# nn.utils.clip_grad_norm_(expert.parameters(), max_norm=20, norm_type=2)
opt.step()
sample_num = ancs.shape[0]
tot_samp_num += sample_num
ep_loss += loss.item() * sample_num
ep_preloss += loss_dict['preloss'].item() * sample_num
ep_regloss += loss_dict['regloss'].item()
log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True)
counter[dataset_id] += 1
if (counter[dataset_id] + 1) % self.multi_handler.trn_handlers[dataset_id].reproj_steps == 0:
# if args.proj_trn_steps > 0 and counter[dataset_id] >= args.proj_trn_steps:
self.multi_handler.trn_handlers[dataset_id].make_projectors()
if (i + 1) % reassign_steps == 0:
self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=False)
ret = dict()
ret['Loss'] = ep_loss / tot_samp_num
ret['preLoss'] = ep_preloss / tot_samp_num
ret['regLoss'] = ep_regloss / steps
t.cuda.empty_cache()
return ret
def make_trn_masks(self, numpy_usrs, csr_mat):
trn_masks = csr_mat[numpy_usrs].tocoo()
cand_size = trn_masks.shape[1]
trn_masks = t.from_numpy(np.stack([trn_masks.row, trn_masks.col], axis=0)).long()
return trn_masks, cand_size
def test_loss_epoch(self, handler, dataset_id):
with t.no_grad():
tst_loader = handler.tst_loss_loader
self.model.eval()
expert = self.model.summon(dataset_id)#.cuda()
ep_loss, ep_preloss, ep_regloss = 0, 0, 0
steps = len(tst_loader)
tot_samp_num = 0
for i, batch_data in enumerate(tst_loader):
ancs, poss, negs = batch_data
ancs = ancs.long()
poss = poss.long()
negs = negs.long()
# adj = handler.tst_input_adj
feats = handler.projectors
loss, loss_dict = expert.cal_loss((ancs, poss, negs), feats)
sample_num = ancs.shape[0]
tot_samp_num += sample_num
ep_loss += loss.item() * sample_num
ep_preloss += loss_dict['preloss'].item() * sample_num
ep_regloss += loss_dict['regloss'].item()
log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True)
ret = dict()
ret['Loss'] = ep_loss / tot_samp_num
ret['preLoss'] = ep_preloss / tot_samp_num
ret['regLoss'] = ep_regloss / steps
ret['tot_samp_num'] = tot_samp_num
t.cuda.empty_cache()
return ret
def test_epoch(self, handler, dataset_id):
with t.no_grad():
tst_loader = handler.tst_loader
class_num = tst_loader.dataset.class_num
self.model.eval()
expert = self.model.summon(dataset_id)
ep_acc, ep_tot = 0, 0
steps = len(tst_loader)
for i, batch_data in enumerate(tst_loader):
nodes, labels = list(map(lambda x: x.long().cuda(), batch_data))
feats = handler.projectors
preds = expert.pred_for_node_test(nodes, class_num, feats, rerun_embed=False if i!=0 else True)
if i == 0:
all_preds, all_labels = preds, labels
else:
all_preds = t.concatenate([all_preds, preds])
all_labels = t.concatenate([all_labels, labels])
hit = (labels == preds).float().sum().item()
ep_acc += hit
ep_tot += labels.shape[0]
log('Steps %d/%d: hit = %d, tot = %d ' % (i, steps, ep_acc, ep_tot), save=False, oneline=True)
ret = dict()
ret['Acc'] = ep_acc / ep_tot
ret['F1'] = f1_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='macro')
ret['tstNum'] = ep_tot
t.cuda.empty_cache()
return ret
def calc_recall_ndcg(self, topLocs, tstLocs, batIds):
assert topLocs.shape[0] == len(batIds)
allRecall = allNdcg = 0
for i in range(len(batIds)):
temTopLocs = list(topLocs[i])
temTstLocs = tstLocs[batIds[i]]
tstNum = len(temTstLocs)
maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))])
recall = dcg = 0
for val in temTstLocs:
if val in temTopLocs:
recall += 1
dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2))
recall = recall / tstNum
ndcg = dcg / maxDcg
allRecall += recall
allNdcg += ndcg
return allRecall, allNdcg
def save_history(self):
if args.epoch == 0:
return
with open('../History/' + args.save_path + '.his', 'wb') as fs:
pickle.dump(self.metrics, fs)
content = {
'model': self.model,
}
t.save(content, '../Models/' + args.save_path + '.mod')
log('Model Saved: %s' % args.save_path)
def load_model(self):
ckp = t.load('../Models/' + args.load_model + '.mod')
self.model = ckp['model']
# self.model.set_initial_projection(self.handler.torch_adj)
self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)
with open('../History/' + args.load_model + '.his', 'rb') as fs:
self.metrics = pickle.load(fs)
log('Model Loaded')
if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if len(args.gpu.split(',')) == 2:
args.devices = ['cuda:0', 'cuda:1']
elif len(args.gpu.split(',')) > 2:
        raise Exception('At most 2 GPU devices are supported')
else:
args.devices = ['cuda:0', 'cuda:0']
logger.saveDefault = True
setproctitle.setproctitle('akaxia_AutoGraph')
log('Start')
datasets = dict()
datasets['all'] = [
'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech', 'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab', 'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1',
]
datasets['ecommerce'] = [
'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech'
]
datasets['academic'] = [
'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab'
]
datasets['others'] = [
'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1'
]
datasets['div1'] = [
'products_tech', 'yelp2018', 'yelp_textfeat', 'products_home', 'steam_textfeat', 'amazon_textfeat', 'amazon-book', 'citation-2019', 'citation-classic', 'pubmed', 'citeseer', 'ppa', 'p2p-Gnutella06', 'soc-Epinions1', 'email-Enron',
]
datasets['div2'] = [
'Photo', 'Goodreads', 'Fitness', 'ml1m', 'ml10m', 'gowalla', 'arxiv', 'arxiv-ta', 'cora', 'CS', 'collab', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'ddi', 'web-Stanford', 'roadNet-PA',
]
datasets['node'] = [
'cora', 'arxiv', 'pubmed', 'home', 'tech'
]
if args.dataset_setting in datasets.keys():
trn_datasets = tst_datasets = datasets[args.dataset_setting]
elif args.dataset_setting in datasets['all']:
trn_datasets = tst_datasets = [args.dataset_setting]
elif '+' in args.dataset_setting:
idx = args.dataset_setting.index('+')
trn_datasets = datasets[args.dataset_setting[:idx]]
tst_datasets = datasets[args.dataset_setting[idx+1:]]
elif '_in_' in args.dataset_setting:
idx = args.dataset_setting.index('_in_')
tst_datasets_1 = datasets[args.dataset_setting[:idx]]
tst_datasets_2 = datasets[args.dataset_setting[idx+len('_in_'):]]
tst_datasets = []
for data in tst_datasets_1:
if data in tst_datasets_2:
tst_datasets.append(data)
trn_datasets = tst_datasets
# trn_datasets = tst_datasets = ['products_home']
if '+' not in args.dataset_setting:
handler = MultiDataHandler(trn_datasets, [tst_datasets])
else:
handler = MultiDataHandler(trn_datasets, [trn_datasets, tst_datasets])
log('Load Data')
exp = Exp(handler)
exp.run()
print(args.load_model, args.dataset_setting)
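The `__main__` block above resolves `args.dataset_setting` into training and test dataset lists, supporting bare group names plus the `a+b` (train on a, test on b) and `a_in_b` (intersection) forms. A minimal standalone sketch of that resolution logic, using a made-up `datasets` dict (the real groups are defined above); the single-dataset-name case is omitted for brevity:

```python
# Hypothetical dataset groups for illustration only.
datasets = {
    'academic': ['cora', 'pubmed', 'citeseer'],
    'ecommerce': ['yelp2018', 'gowalla'],
    'div1': ['pubmed', 'yelp2018'],
}

def resolve_setting(setting):
    # 'a+b': train on group a, test on group b.
    if '+' in setting:
        trn_key, tst_key = setting.split('+', 1)
        return datasets[trn_key], datasets[tst_key]
    # 'a_in_b': both train and test on the intersection of groups a and b.
    if '_in_' in setting:
        key1, key2 = setting.split('_in_', 1)
        shared = [d for d in datasets[key1] if d in datasets[key2]]
        return shared, shared
    # A bare group name selects that group for both training and testing.
    group = datasets[setting]
    return group, group
```

This mirrors the branching in the `__main__` block; the real code checks the group-name and single-dataset cases before the `+` and `_in_` forms.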
================================================
FILE: node_classification/model.py
================================================
import torch as t
from torch import nn
import torch.nn.functional as F
from params import args
import numpy as np
from Utils.TimeLogger import log
from torch.nn import MultiheadAttention
from time import time
init = nn.init.xavier_uniform_
# init = nn.init.normal_
uniformInit = nn.init.uniform_
class FeedForwardLayer(nn.Module):
def __init__(self, in_feat, out_feat, bias=True, act=None):
super(FeedForwardLayer, self).__init__()
self.linear = nn.Linear(in_feat, out_feat, bias=bias)#, dtype=t.bfloat16)
if act == 'identity' or act is None:
self.act = None
elif act == 'leaky':
self.act = nn.LeakyReLU(negative_slope=args.leaky)
elif act == 'relu':
self.act = nn.ReLU()
elif act == 'relu6':
self.act = nn.ReLU6()
else:
raise Exception('Error')
def forward(self, embeds):
if self.act is None:
return self.linear(embeds)
return self.act(self.linear(embeds))
class TopoEncoder(nn.Module):
def __init__(self):
super(TopoEncoder, self).__init__()
self.layer_norm = nn.LayerNorm(args.latdim, elementwise_affine=False)
def forward(self, adj, embeds, normed=False):
with t.no_grad():
if not normed:
embeds = self.layer_norm(embeds)
# embeds_list = []
final_embeds = 0
if args.gnn_layer == 0:
final_embeds = embeds
# embeds_list.append(embeds)
for _ in range(args.gnn_layer):
embeds = t.spmm(adj, embeds)
final_embeds += embeds
# embeds_list.append(embeds)
embeds = final_embeds#sum(embeds_list)
return embeds
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(args.fc_layer)])
self.layer_norms = nn.Sequential(*[nn.LayerNorm(args.latdim, elementwise_affine=True) for _ in range(args.fc_layer)])
self.dropout = nn.Dropout(p=args.drop_rate)
def forward(self, embeds):
for i in range(args.fc_layer):
embeds = self.layer_norms[i](self.dropout(self.dense_layers[i](embeds)) + embeds)
return embeds
class GTLayer(nn.Module):
def __init__(self):
super(GTLayer, self).__init__()
self.multi_head_attention = MultiheadAttention(args.latdim, args.head, dropout=0.1, bias=False)#, dtype=t.bfloat16)
self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(2)])# bias=False
self.layer_norm1 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
self.layer_norm2 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
self.fc_dropout = nn.Dropout(p=args.drop_rate)
def _pick_anchors(self, embeds):
perm = t.randperm(embeds.shape[0])
anchors = perm[:args.anchor]
return embeds[anchors]
def forward(self, embeds):
anchor_embeds = self._pick_anchors(embeds)
_anchor_embeds, _ = self.multi_head_attention(anchor_embeds, embeds, embeds)
anchor_embeds = _anchor_embeds + anchor_embeds
_embeds, _ = self.multi_head_attention(embeds, anchor_embeds, anchor_embeds, need_weights=False)
embeds = self.layer_norm1(_embeds + embeds)
_embeds = self.fc_dropout(self.dense_layers(embeds))
embeds = (self.layer_norm2(_embeds + embeds))
return embeds
class GraphTransformer(nn.Module):
def __init__(self):
super(GraphTransformer, self).__init__()
self.gt_layers = nn.Sequential(*[GTLayer() for i in range(args.gt_layer)])
def forward(self, embeds):
for i, layer in enumerate(self.gt_layers):
embeds = layer(embeds) / args.scale_layer
return embeds
class Feat_Projector(nn.Module):
def __init__(self, feats):
super(Feat_Projector, self).__init__()
if args.proj_method == 'uniform':
self.proj_embeds = self.uniform_proj(feats)
elif args.proj_method == 'svd' or args.proj_method == 'both':
self.proj_embeds = self.svd_proj(feats)
elif args.proj_method == 'random':
self.proj_embeds = self.random_proj(feats)
elif args.proj_method == 'original':
self.proj_embeds = feats
self.proj_embeds = t.flip(self.proj_embeds, dims=[-1])
self.proj_embeds = self.proj_embeds.detach()
def svd_proj(self, feats):
if args.latdim > feats.shape[0] or args.latdim > feats.shape[1]:
dim = min(feats.shape[0], feats.shape[1])
decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=dim, niter=args.niter)
decom_feats = t.concat([decom_feats, t.zeros([decom_feats.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
s = t.concat([s, t.zeros(args.latdim - dim).to(args.devices[0])])
else:
decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=args.latdim, niter=args.niter)
decom_feats = decom_feats @ t.diag(t.sqrt(s))
return decom_feats.cpu()
def uniform_proj(self, feats):
projection = init(t.empty(args.featdim, args.latdim))
return feats @ projection
def random_proj(self, feats):
projection = init(t.empty(feats.shape[0], args.latdim))
return projection
def forward(self):
return self.proj_embeds
class Adj_Projector(nn.Module):
def __init__(self, adj):
super(Adj_Projector, self).__init__()
if args.proj_method == 'adj_svd' or args.proj_method == 'both':
self.proj_embeds = self.svd_proj(adj)
self.proj_embeds = self.proj_embeds.detach()
def svd_proj(self, adj):
q = args.latdim
if args.latdim > adj.shape[0] or args.latdim > adj.shape[1]:
dim = min(adj.shape[0], adj.shape[1])
svd_u, s, svd_v = t.svd_lowrank(adj, q=dim, niter=args.niter)
svd_u = t.concat([svd_u, t.zeros([svd_u.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
svd_v = t.concat([svd_v, t.zeros([svd_v.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
s = t.concat([s, t.zeros(args.latdim-dim).to(args.devices[0])])
else:
svd_u, s, svd_v = t.svd_lowrank(adj, q=q, niter=args.niter)
svd_u = svd_u @ t.diag(t.sqrt(s))
svd_v = svd_v @ t.diag(t.sqrt(s))
if adj.shape[0] != adj.shape[1]:
projection = t.concat([svd_u, svd_v], dim=0)
else:
projection = svd_u + svd_v
return projection.cpu()
def forward(self):
return self.proj_embeds
class Expert(nn.Module):
def __init__(self):
super(Expert, self).__init__()
self.topo_encoder = TopoEncoder().to(args.devices[0])
if args.nn == 'mlp':
self.trainable_nn = MLP().to(args.devices[1])
else:
self.trainable_nn = GraphTransformer().to(args.devices[1])
self.trn_count = 1
def forward(self, projectors, pck_nodes=None):
embeds = projectors.to(args.devices[1])
if pck_nodes is not None:
embeds = embeds[pck_nodes]
embeds = self.trainable_nn(embeds)
return embeds
def pred_norm(self, pos_preds, neg_preds):
pos_preds_num = pos_preds.shape[0]
neg_preds_shape = neg_preds.shape
preds = t.concat([pos_preds, neg_preds.view(-1)])
preds = preds - preds.max()
pos_preds = preds[:pos_preds_num]
neg_preds = preds[pos_preds_num:].view(neg_preds_shape)
return pos_preds, neg_preds
def cal_loss(self, batch_data, projectors):
ancs, poss, negs = list(map(lambda x: x.to(args.devices[1]), batch_data))
self.trn_count += ancs.shape[0]
pck_nodes = t.concat([ancs, poss, negs])
final_embeds = self.forward(projectors, pck_nodes)
# anc_embeds, pos_embeds, neg_embeds = final_embeds[ancs], final_embeds[poss], final_embeds[negs]
anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds, [ancs.shape[0]] * 3)
if final_embeds.isinf().any() or final_embeds.isnan().any():
            raise Exception('Final embeddings contain NaN or Inf')
if args.loss == 'ce':
pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T)
if pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any():
                raise Exception('Predictions contain NaN or Inf')
pos_loss = pos_preds
neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log()
pre_loss = -(pos_loss - neg_loss).mean()
elif args.loss == 'bpr':
pos_preds = (anc_embeds * pos_embeds).sum(-1)
neg_preds = (anc_embeds * neg_embeds).sum(-1)
pos_loss, neg_loss = pos_preds, neg_preds
pre_loss = -((pos_preds - neg_preds).sigmoid() + 1e-10).log().mean()
if t.isinf(pre_loss).any() or t.isnan(pre_loss).any():
raise Exception('NaN or Inf')
reg_loss = sum(list(map(lambda W: W.norm(2).square() * args.reg, self.parameters())))
loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean()}
return pre_loss + reg_loss, loss_dict
def pred_for_test(self, batch_data, cand_size, projectors, rerun_embed=True):
ancs, trn_mask = list(map(lambda x: x.to(args.devices[1]), batch_data))
if rerun_embed:
try:
final_embeds = self.forward(projectors)
except Exception:
final_embeds_list = []
div = args.batch * 3
temlen = projectors.shape[0] // div
for i in range(temlen):
st, ed = div * i, div * (i + 1)
tem_projectors = projectors[st: ed, :]
final_embeds_list.append(self.forward(tem_projectors))
if temlen * div < projectors.shape[0]:
tem_projectors = projectors[temlen*div:, :]
final_embeds_list.append(self.forward(tem_projectors))
final_embeds = t.concat(final_embeds_list, dim=0)
self.final_embeds = final_embeds
final_embeds = self.final_embeds
anc_embeds = final_embeds[ancs]
cand_embeds = final_embeds[-cand_size:]
mask_mat = t.sparse.FloatTensor(trn_mask, t.ones(trn_mask.shape[1]).to(args.devices[1]), t.Size([ancs.shape[0], cand_size]))
dense_mat = mask_mat.to_dense()
all_preds = anc_embeds @ cand_embeds.T * (1 - dense_mat) - dense_mat * 1e8
return all_preds
def pred_for_node_test(self, nodes, cand_size, feats, rerun_embed=True):
if rerun_embed:
final_embeds = self.forward(feats)
self.final_embeds = final_embeds
final_embeds = self.final_embeds
anc_embeds = final_embeds[nodes]
class_embeds = final_embeds[-cand_size:]
preds = anc_embeds @ class_embeds.T
return t.argmax(preds, dim=-1)
def attempt(self, topo_embeds, dataset):
final_embeds = self.trainable_nn(topo_embeds)
rows, cols, negs = list(map(lambda x: t.from_numpy(x).long().to(args.devices[1]), [dataset.ancs, dataset.poss, dataset.negs]))
if rows.shape[0] > args.attempt_cache:
random_perm = t.randperm(rows.shape[0], device=args.devices[0])
pck_perm = random_perm[:args.attempt_cache]
rows = rows[pck_perm]
cols = cols[pck_perm]
negs = negs[pck_perm]
while True:
try:
row_embeds = final_embeds[rows]
col_embeds = final_embeds[cols]
neg_embeds = final_embeds[negs]
score = ((row_embeds * col_embeds).sum(-1) - (row_embeds * neg_embeds).sum(-1)).sigmoid().mean().item()
break
except Exception:
args.attempt_cache = args.attempt_cache // 2
random_perm = t.randperm(rows.shape[0], device=args.devices[0])
pck_perm = random_perm[:args.attempt_cache]
rows = rows[pck_perm]
cols = cols[pck_perm]
negs = negs[pck_perm]
t.cuda.empty_cache()
return score
class AnyGraph(nn.Module):
def __init__(self):
super(AnyGraph, self).__init__()
self.experts = nn.ModuleList([Expert() for _ in range(args.expert_num)]).cuda()
self.opts = list(map(lambda expert: t.optim.Adam(expert.parameters(), lr=args.lr, weight_decay=0), self.experts))
def one_graph_one_expert(self, handlers, log_assignment=False):
self.assignment = [i for i in range(len(handlers))]
if log_assignment:
print('\n----------\nAssignment')
for dataset_id, handler in enumerate(handlers):
print(handler.data_name, f'{self.assignment[dataset_id]}')
def store_history_assignment(self, assignment):
pass
def assign_experts(self, handlers, reca=True, log_assignment=False):
if args.expert_num == 1:
self.assignment = [0] * len(handlers)
return
if args.assignment == 'one-graph-one-expert':
self.one_graph_one_expert(handlers, log_assignment)
return
if args.assignment not in ['top1', 'top2']:
            raise Exception('Unrecognized assignment method')
try:
expert_scores = np.array(list(map(lambda expert: expert.trn_count, self.experts)))
expert_scores = (1.0 - expert_scores / np.sum(expert_scores)) * args.reca_range + 1.0 - args.reca_range / 2
except Exception:
expert_scores = np.ones(len(self.experts))
with t.no_grad():
assignment = [list() for i in range(len(handlers))]
for dataset_id, handler in enumerate(handlers):
topo_embeds = handler.projectors.to(args.devices[1])
for expert_id, expert in enumerate(self.experts):
expert = expert.to(args.devices[1])
score = expert.attempt(topo_embeds, handler.trn_loader.dataset)
if reca:
score *= expert_scores[expert_id]
assignment[dataset_id].append((expert_id, score))
assignment[dataset_id].sort(key=lambda x: x[1], reverse=True)
if args.assignment == 'top2':
if assignment[dataset_id][0][1] - assignment[dataset_id][1][1] < 0.05 and np.random.uniform() > 0.5:
tem = assignment[dataset_id][0]
assignment[dataset_id][0] = assignment[dataset_id][1]
assignment[dataset_id][1] = tem
if log_assignment:
print('\n----------\nAssignment')
for dataset_id, handler in enumerate(handlers):
out = ''
for exp_idx in range(min(4, len(self.experts))):
out += f'({assignment[dataset_id][exp_idx][0]}, {assignment[dataset_id][exp_idx][1]}) '
print(handler.data_name, out)
print('----------\n')
self.assignment = list(map(lambda x: x[0][0], assignment))
def summon(self, dataset_id):
return self.experts[self.assignment[dataset_id]]
def summon_opt(self, dataset_id):
return self.opts[self.assignment[dataset_id]]
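`Expert.pred_norm` above shifts all logits by their global maximum before the `ce` loss exponentiates them, so `exp()` cannot overflow; the cross-entropy value is invariant to this constant shift. A self-contained sketch of that trick with plain Python floats (function names mirror the model, the scalar setting is a simplification):

```python
import math

def pred_norm(pos_pred, neg_preds):
    # Shift every logit by the global max; exp() of the shifted
    # values is at most 1, so the sum below cannot overflow.
    m = max([pos_pred] + neg_preds)
    return pos_pred - m, [p - m for p in neg_preds]

def ce_loss(pos_pred, neg_preds):
    pos, negs = pred_norm(pos_pred, neg_preds)
    # Same form as the 'ce' branch in cal_loss: the denominator
    # includes the positive logit and a small epsilon.
    neg_term = math.log(sum(math.exp(p) for p in negs) + math.exp(pos) + 1e-8)
    return -(pos - neg_term)

# Logits this large would overflow a naive exp(); the shifted version is stable.
loss = ce_loss(1000.0, [999.0, 998.0])
```

Without the shift, `math.exp(1000.0)` raises `OverflowError`; with it, the loss is a small finite number.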
================================================
FILE: node_classification/params.py
================================================
import argparse
def parse_args():
parser = argparse.ArgumentParser(description='Model Parameters')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--batch', default=4096, type=int, help='training batch size')
parser.add_argument('--tst_batch', default=256, type=int, help='testing batch size (number of users)')
parser.add_argument('--epoch', default=100, type=int, help='number of epochs')
parser.add_argument('--save_path', default='tem', help='file name to save model and training record')
parser.add_argument('--load_model', default=None, help='model name to load')
parser.add_argument('--data', default='ml1m', type=str, help='name of dataset')
parser.add_argument('--tst_epoch', default=1, type=int, help='number of epoch to test while training')
parser.add_argument('--gpu', default='0', type=str, help='indicates which gpu to use')
parser.add_argument('--topk', default=20, type=int, help='topk in evaluation')
parser.add_argument('--epoch_max_step', default=-1, type=int, help='indicates the maximum number of steps in one epoch, -1 denotes full steps')
parser.add_argument('--trn_mode', default='train-all', type=str, help='[fewshot, train-all]')
parser.add_argument('--tst_mode', default='tst', type=str, help='[tst, val]')
    parser.add_argument('--eval_loss', default=True, type=bool, help='whether to use CE loss to evaluate test performance')
parser.add_argument('--ratio_fewshot_set', default=0.5, type=float, help='ratio of fewshot set')
parser.add_argument('--shot', default=5, type=int, help='number of shots for each node')
parser.add_argument('--tst_steps', default=-1, type=int, help='number of test steps, -1 indicates all')
parser.add_argument('--reg', default=1e-7, type=float, help='weight decay regularizer')
parser.add_argument('--latdim', default=512, type=int, help='latent dimensionality')
parser.add_argument('--gnn_layer', default=3, type=int, help='number of gnn layers')
parser.add_argument('--fc_layer', default=8, type=int, help='number of fully-connected layers')
parser.add_argument('--gt_layer', default=2, type=int, help='number of graph transformer layers')
parser.add_argument('--head', default=4, type=int, help='number of attention heads')
parser.add_argument('--anchor', default=256, type=int, help='number of anchor nodes in the compressed graph transformer')
parser.add_argument('--act', default='relu', type=str, help='activation function')
parser.add_argument('--dataset_setting', default='training', type=str, help='which set of datasets to use')
parser.add_argument('--assignment', default='top1', type=str, help='assigning method')
parser.add_argument('--scale_layer', default=10, type=float, help='per-layer scale factor')
parser.add_argument('--leaky', default=0.5, type=float, help='slope of leaky relu activation')
parser.add_argument('--drop_rate', default=0.1, type=float, help='ratio of dropout')
parser.add_argument('--reca_range', default=0.2, type=float, help='range of recalibration')
    parser.add_argument('--selfloop', default=0, type=int, help='whether to add self-loops to the adjacency matrix')
parser.add_argument('--niter', default=2, type=int, help='number of iterations in svd')
parser.add_argument('--expert_num', default=8, type=int, help='number of experts')
parser.add_argument('--loss', default='ce', type=str, help='loss function')
parser.add_argument('--proj_method', default='both', type=str, help='feature projection method')
parser.add_argument('--nn', default='mlp', type=str, help='what trainable network to use')
parser.add_argument('--proj_trn_steps', default=100, type=int, help='number of training steps for one initial projection')
    parser.add_argument('--attempt_cache', default=10000000, type=int, help='maximum number of samples used when scoring experts for assignment')
return parser.parse_args()
args = parse_args()
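The `--reca_range` flag above controls the recalibration applied in `AnyGraph.assign_experts`: each expert's attempt score is multiplied by a factor that boosts experts trained on fewer samples. A small pure-Python sketch of that formula (the `trn_counts` values are made up):

```python
def recalibrate(trn_counts, reca_range=0.2):
    # Experts with a smaller share of the total training samples get a
    # score factor above 1.0, nudging datasets toward under-used experts;
    # the factors spread over an interval of width reca_range around 1.0.
    total = sum(trn_counts)
    return [(1.0 - c / total) * reca_range + 1.0 - reca_range / 2
            for c in trn_counts]

# Expert 0 has seen fewer samples, so its factor exceeds expert 1's.
scores = recalibrate([100, 300])
```

With the default `reca_range=0.2`, the factors here come out to 1.05 and 0.95.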
================================================
FILE: params.py
================================================
import argparse
def parse_args():
parser = argparse.ArgumentParser(description='Model Parameters')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--batch', default=4096, type=int, help='training batch size')
parser.add_argument('--tst_batch', default=256, type=int, help='testing batch size (number of users)')
parser.add_argument('--epoch', default=100, type=int, help='number of epochs')
parser.add_argument('--save_path', default='tem', help='file name to save model and training record')
parser.add_argument('--load_model', default=None, help='model name to load')
parser.add_argument('--tst_epoch', default=1, type=int, help='number of epoch to test while training')
parser.add_argument('--gpu', default='0', type=str, help='indicates which gpu to use')
parser.add_argument('--topk', default=20, type=int, help='topk in evaluation')
parser.add_argument('--epoch_max_step', default=-1, type=int, help='indicates the maximum number of steps in one epoch, -1 denotes full steps')
parser.add_argument('--trn_mode', default='train-all', type=str, help='[fewshot, train-all]')
parser.add_argument('--tst_mode', default='tst', type=str, help='[tst, val]')
    parser.add_argument('--eval_loss', default=True, type=bool, help='whether to use CE loss to evaluate test performance')
parser.add_argument('--ratio_fewshot', default=0.1, type=float, help='ratio of fewshot set')
parser.add_argument('--tst_steps', default=-1, type=int, help='number of test steps, -1 indicates all')
parser.add_argument('--reg', default=1e-7, type=float, help='weight decay regularizer')
parser.add_argument('--latdim', default=512, type=int, help='latent dimensionality')
parser.add_argument('--gnn_layer', default=3, type=int, help='number of gnn layers')
parser.add_argument('--fc_layer', default=8, type=int, help='number of fully-connected layers')
parser.add_argument('--gt_layer', default=2, type=int, help='number of graph transformer layers')
parser.add_argument('--head', default=4, type=int, help='number of attention heads')
parser.add_argument('--anchor', default=256, type=int, help='number of anchor nodes in the compressed graph transformer')
parser.add_argument('--act', default='relu', type=str, help='activation function')
parser.add_argument('--dataset_setting', default='training', type=str, help='which set of datasets to use')
parser.add_argument('--scale_layer', default=10, type=float, help='per-layer scale factor')
parser.add_argument('--leaky', default=0.5, type=float, help='slope of leaky relu activation')
parser.add_argument('--drop_rate', default=0.1, type=float, help='ratio of dropout')
parser.add_argument('--reca_range', default=0.2, type=float, help='range of recalibration')
    parser.add_argument('--selfloop', default=0, type=int, help='whether to add self-loops to the adjacency matrix')
parser.add_argument('--niter', default=2, type=int, help='number of iterations in svd')
parser.add_argument('--expert_num', default=8, type=int, help='number of experts')
parser.add_argument('--loss', default='ce', type=str, help='loss function')
parser.add_argument('--proj_method', default='both', type=str, help='feature projection method')
parser.add_argument('--nn', default='mlp', type=str, help='what trainable network to use')
parser.add_argument('--proj_trn_steps', default=100, type=int, help='number of training steps for one initial projection')
    parser.add_argument('--attempt_cache', default=10000000, type=int, help='maximum number of samples used when scoring experts for assignment')
return parser.parse_args()
args = parse_args()
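The parser above reads `sys.argv` at import time via `args = parse_args()`; `main.py` then splits `args.gpu` on `,` to decide whether to run on one or two devices. A trimmed, hypothetical copy of a few of the flags, parsed from an explicit argv list instead of the real command line:

```python
import argparse

# Trimmed copy of the parser above, for illustration only.
parser = argparse.ArgumentParser(description='Model Parameters')
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--latdim', default=512, type=int)
parser.add_argument('--gpu', default='0', type=str)

# Passing a list to parse_args() bypasses sys.argv; unspecified
# flags keep their defaults.
parsed = parser.parse_args(['--latdim', '256', '--gpu', '0,1'])
```

Here `parsed.gpu.split(',')` has length 2, which is the condition `main.py` uses to enable the two-device setup.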
SYMBOL INDEX (166 symbols across 10 files)
FILE: Utils/TimeLogger.py
function log (line 6) | def log(msg, save=None, oneline=False):
function marktime (line 21) | def marktime(marker):
FILE: data_handler.py
class MultiDataHandler (line 13) | class MultiDataHandler:
method __init__ (line 14) | def __init__(self, trn_datasets, tst_datasets_group):
method make_joint_trn_loader (line 36) | def make_joint_trn_loader(self):
method remake_initial_projections (line 44) | def remake_initial_projections(self):
class DataHandler (line 49) | class DataHandler:
method __init__ (line 50) | def __init__(self, data_name):
method get_data_files (line 57) | def get_data_files(self):
method load_one_file (line 68) | def load_one_file(self, filename):
method load_feats (line 75) | def load_feats(self, filename):
method normalize_adj (line 84) | def normalize_adj(self, mat, log=False):
method unique_numpy (line 99) | def unique_numpy(self, row, col):
method make_torch_adj (line 106) | def make_torch_adj(self, mat, unidirectional_for_asym=False):
method load_data (line 148) | def load_data(self):
method make_projectors (line 207) | def make_projectors(self):
class TstData (line 238) | class TstData(data.Dataset):
method __init__ (line 239) | def __init__(self, coomat, trn_mat):
method __len__ (line 254) | def __len__(self):
method __getitem__ (line 257) | def __getitem__(self, idx):
class TrnData (line 260) | class TrnData(data.Dataset):
method __init__ (line 261) | def __init__(self, coomat):
method neg_sampling (line 269) | def neg_sampling(self):
method __len__ (line 272) | def __len__(self):
method __getitem__ (line 275) | def __getitem__(self, idx):
class JointTrnData (line 278) | class JointTrnData(data.Dataset):
method __init__ (line 279) | def __init__(self, dataset_list):
method neg_sampling (line 291) | def neg_sampling(self):
method __len__ (line 295) | def __len__(self):
method __getitem__ (line 298) | def __getitem__(self, idx):
FILE: main.py
class Exp (line 14) | class Exp:
method __init__ (line 15) | def __init__(self, multi_handler):
method make_print (line 31) | def make_print(self, name, ep, reses, save, data_name=None):
method run (line 45) | def run(self):
method print_model_size (line 120) | def print_model_size(self):
method prepare_model (line 135) | def prepare_model(self):
method train_epoch (line 140) | def train_epoch(self):
method make_trn_masks (line 189) | def make_trn_masks(self, numpy_usrs, csr_mat):
method test_epoch (line 195) | def test_epoch(self, handler, dataset_id):
method calc_recall_ndcg (line 226) | def calc_recall_ndcg(self, topLocs, tstLocs, batIds):
method save_history (line 245) | def save_history(self):
method load_model (line 257) | def load_model(self):
FILE: model.py
class FeedForwardLayer (line 13) | class FeedForwardLayer(nn.Module):
method __init__ (line 14) | def __init__(self, in_feat, out_feat, bias=True, act=None):
method forward (line 28) | def forward(self, embeds):
class TopoEncoder (line 33) | class TopoEncoder(nn.Module):
method __init__ (line 34) | def __init__(self):
method forward (line 39) | def forward(self, adj, embeds, normed=False):
class MLP (line 55) | class MLP(nn.Module):
method __init__ (line 56) | def __init__(self):
method forward (line 62) | def forward(self, embeds):
class GTLayer (line 67) | class GTLayer(nn.Module):
method __init__ (line 68) | def __init__(self):
method _pick_anchors (line 76) | def _pick_anchors(self, embeds):
method forward (line 81) | def forward(self, embeds):
class GraphTransformer (line 91) | class GraphTransformer(nn.Module):
method __init__ (line 92) | def __init__(self):
method forward (line 96) | def forward(self, embeds):
class Feat_Projector (line 101) | class Feat_Projector(nn.Module):
method __init__ (line 102) | def __init__(self, feats):
method svd_proj (line 116) | def svd_proj(self, feats):
method uniform_proj (line 127) | def uniform_proj(self, feats):
method random_proj (line 131) | def random_proj(self, feats):
method forward (line 135) | def forward(self):
class Adj_Projector (line 138) | class Adj_Projector(nn.Module):
method __init__ (line 139) | def __init__(self, adj):
method svd_proj (line 146) | def svd_proj(self, adj):
method forward (line 164) | def forward(self):
class Expert (line 167) | class Expert(nn.Module):
method __init__ (line 168) | def __init__(self):
method forward (line 178) | def forward(self, projectors, pck_nodes=None):
method pred_norm (line 185) | def pred_norm(self, pos_preds, neg_preds):
method cal_loss (line 194) | def cal_loss(self, batch_data, projectors):
method pred_for_test (line 224) | def pred_for_test(self, batch_data, cand_size, projectors, rerun_embed...
method attempt (line 251) | def attempt(self, topo_embeds, dataset):
class AnyGraph (line 277) | class AnyGraph(nn.Module):
method __init__ (line 278) | def __init__(self):
method assign_experts (line 283) | def assign_experts(self, handlers, reca=True, log_assignment=False):
method summon (line 314) | def summon(self, dataset_id):
method summon_opt (line 317) | def summon_opt(self, dataset_id):
FILE: node_classification/Utils/TimeLogger.py
function log (line 6) | def log(msg, save=None, oneline=False):
function marktime (line 21) | def marktime(marker):
FILE: node_classification/data_handler.py
class MultiDataHandler (line 13) | class MultiDataHandler:
method __init__ (line 14) | def __init__(self, trn_datasets, tst_datasets_group):
method make_joint_trn_loader (line 36) | def make_joint_trn_loader(self):
method remake_initial_projections (line 44) | def remake_initial_projections(self):
class DataHandler (line 49) | class DataHandler:
method __init__ (line 50) | def __init__(self, data_name):
method get_data_files (line 57) | def get_data_files(self):
method load_one_file (line 69) | def load_one_file(self, filename):
method load_feats (line 76) | def load_feats(self, filename):
method normalize_adj (line 85) | def normalize_adj(self, mat, log=False):
method unique_numpy (line 100) | def unique_numpy(self, row, col):
method make_torch_adj (line 107) | def make_torch_adj(self, mat, unidirectional_for_asym=False):
method load_data (line 149) | def load_data(self):
method make_projectors (line 196) | def make_projectors(self):
class TrnData (line 227) | class TrnData(data.Dataset):
method __init__ (line 228) | def __init__(self, coomat):
method neg_sampling (line 238) | def neg_sampling(self):
method __len__ (line 250) | def __len__(self):
method __getitem__ (line 253) | def __getitem__(self, idx):
class NodeTstData (line 256) | class NodeTstData(data.Dataset):
method __init__ (line 257) | def __init__(self, tst_mat):
method __len__ (line 261) | def __len__(self):
method __getitem__ (line 264) | def __getitem__(self, idx):
class JointTrnData (line 267) | class JointTrnData(data.Dataset):
method __init__ (line 268) | def __init__(self, dataset_list):
method neg_sampling (line 280) | def neg_sampling(self):
method __len__ (line 284) | def __len__(self):
method __getitem__ (line 287) | def __getitem__(self, idx):
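`DataHandler.normalize_adj` presumably applies the symmetric normalization D^{-1/2} A D^{-1/2} standard in GCN-style propagation, and the preview shows the handler importing `csr_matrix` and `coo_matrix` from SciPy. A sketch under those assumptions — not the repository's exact code:

```python
import numpy as np
from scipy.sparse import coo_matrix, csr_matrix

def normalize_adj(mat):
    # Symmetric normalization D^{-1/2} A D^{-1/2}, where D is the diagonal
    # degree matrix; isolated nodes (degree 0) are zeroed out.
    degree = np.asarray(mat.sum(axis=1)).flatten()
    d_inv_sqrt = np.where(degree > 0, degree, 1.0) ** -0.5
    d_inv_sqrt[degree == 0] = 0.0
    mat = mat.tocoo()
    vals = mat.data * d_inv_sqrt[mat.row] * d_inv_sqrt[mat.col]
    return coo_matrix((vals, (mat.row, mat.col)), shape=mat.shape)

# Toy usage: a 3-node path graph 0 - 1 - 2.
path = csr_matrix(np.array([
    [0., 1., 0.],
    [1., 0., 1.],
    [0., 1., 0.],
]))
normed = normalize_adj(path).toarray()
```

Each edge weight becomes 1 / sqrt(deg(u) * deg(v)), so the edges touching the degree-2 center node get weight 1/sqrt(2).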
FILE: node_classification/main.py
class Exp (line 15) | class Exp:
method __init__ (line 16) | def __init__(self, multi_handler):
method make_print (line 32) | def make_print(self, name, ep, reses, save, data_name=None):
method run (line 46) | def run(self):
method print_model_size (line 124) | def print_model_size(self):
method prepare_model (line 139) | def prepare_model(self):
method train_epoch (line 144) | def train_epoch(self):
method make_trn_masks (line 196) | def make_trn_masks(self, numpy_usrs, csr_mat):
method test_loss_epoch (line 202) | def test_loss_epoch(self, handler, dataset_id):
method test_epoch (line 234) | def test_epoch(self, handler, dataset_id):
method calc_recall_ndcg (line 263) | def calc_recall_ndcg(self, topLocs, tstLocs, batIds):
method save_history (line 282) | def save_history(self):
method load_model (line 294) | def load_model(self):
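`Exp.calc_recall_ndcg` evaluates ranked predictions against held-out test items. A simplified standalone sketch of Recall@K and NDCG@K — the repository's version also takes `batIds` and works in batches, so this form is illustrative only:

```python
import numpy as np

def calc_recall_ndcg(top_locs, tst_locs):
    # top_locs: per-user ranked lists of predicted item ids (length K).
    # tst_locs: per-user collections of held-out ground-truth item ids.
    recalls, ndcgs = [], []
    for ranked, truth in zip(top_locs, tst_locs):
        truth = set(truth)
        if not truth:
            continue
        hits = [1.0 if item in truth else 0.0 for item in ranked]
        recalls.append(sum(hits) / len(truth))
        # DCG discounts a hit at rank i by 1/log2(i + 2); IDCG is the DCG of
        # an ideal ranking that lists every ground-truth item first.
        dcg = sum(h / np.log2(i + 2) for i, h in enumerate(hits))
        idcg = sum(1.0 / np.log2(i + 2)
                   for i in range(min(len(truth), len(ranked))))
        ndcgs.append(dcg / idcg)
    return float(np.mean(recalls)), float(np.mean(ndcgs))

# Toy usage: one user, K = 3, two relevant items, one recovered at rank 0.
recall, ndcg = calc_recall_ndcg([[5, 9, 2]], [[5, 7]])
```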
FILE: node_classification/model.py
class FeedForwardLayer (line 14) | class FeedForwardLayer(nn.Module):
method __init__ (line 15) | def __init__(self, in_feat, out_feat, bias=True, act=None):
method forward (line 29) | def forward(self, embeds):
class TopoEncoder (line 34) | class TopoEncoder(nn.Module):
method __init__ (line 35) | def __init__(self):
method forward (line 40) | def forward(self, adj, embeds, normed=False):
class MLP (line 56) | class MLP(nn.Module):
method __init__ (line 57) | def __init__(self):
method forward (line 63) | def forward(self, embeds):
class GTLayer (line 68) | class GTLayer(nn.Module):
method __init__ (line 69) | def __init__(self):
method _pick_anchors (line 77) | def _pick_anchors(self, embeds):
method forward (line 82) | def forward(self, embeds):
class GraphTransformer (line 92) | class GraphTransformer(nn.Module):
method __init__ (line 93) | def __init__(self):
method forward (line 97) | def forward(self, embeds):
class Feat_Projector (line 102) | class Feat_Projector(nn.Module):
method __init__ (line 103) | def __init__(self, feats):
method svd_proj (line 117) | def svd_proj(self, feats):
method uniform_proj (line 128) | def uniform_proj(self, feats):
method random_proj (line 132) | def random_proj(self, feats):
method forward (line 136) | def forward(self):
class Adj_Projector (line 139) | class Adj_Projector(nn.Module):
method __init__ (line 140) | def __init__(self, adj):
method svd_proj (line 147) | def svd_proj(self, adj):
method forward (line 165) | def forward(self):
class Expert (line 168) | class Expert(nn.Module):
method __init__ (line 169) | def __init__(self):
method forward (line 179) | def forward(self, projectors, pck_nodes=None):
method pred_norm (line 186) | def pred_norm(self, pos_preds, neg_preds):
method cal_loss (line 195) | def cal_loss(self, batch_data, projectors):
method pred_for_test (line 225) | def pred_for_test(self, batch_data, cand_size, projectors, rerun_embed...
method pred_for_node_test (line 252) | def pred_for_node_test(self, nodes, cand_size, feats, rerun_embed=True):
method attempt (line 262) | def attempt(self, topo_embeds, dataset):
class AnyGraph (line 288) | class AnyGraph(nn.Module):
method __init__ (line 289) | def __init__(self):
method one_graph_one_expert (line 294) | def one_graph_one_expert(self, handlers, log_assignment=False):
method store_history_assignment (line 301) | def store_history_assignment(self, assignment):
method assign_experts (line 304) | def assign_experts(self, handlers, reca=True, log_assignment=False):
method summon (line 345) | def summon(self, dataset_id):
method summon_opt (line 348) | def summon_opt(self, dataset_id):
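The `GTLayer._pick_anchors` entry suggests the graph transformer attends to a sampled set of anchor nodes rather than all node pairs. A hedged NumPy sketch of that idea, assuming uniform random anchor sampling — the repository's torch implementation may pick anchors and shape the attention differently:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def anchor_attention(embeds, num_anchors, rng):
    # Full self-attention over N nodes costs O(N^2); attending only to a
    # random subset of `num_anchors` nodes reduces this to O(N * A).
    n, d = embeds.shape
    anchors = embeds[rng.choice(n, size=num_anchors, replace=False)]
    scores = softmax(embeds @ anchors.T / np.sqrt(d), axis=-1)
    return scores @ anchors   # each node becomes a convex mix of anchors

rng = np.random.default_rng(0)
embeds = rng.standard_normal((100, 16))
out = anchor_attention(embeds, num_anchors=8, rng=rng)
```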
FILE: node_classification/params.py
function parse_args (line 3) | def parse_args():
FILE: params.py
function parse_args (line 3) | def parse_args():
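Both `params.py` files expose a `parse_args()` built on `argparse` (the extraction preview shows `argparse.ArgumentParser(description='Model Parameters')`). A minimal sketch with hypothetical flags — the real scripts define many more hyper-parameters:

```python
import argparse

def parse_args(argv=None):
    # Hypothetical subset of flags for illustration; the real params.py
    # files define the project's full hyper-parameter set.
    parser = argparse.ArgumentParser(description='Model Parameters')
    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--latdim', default=64, type=int, help='embedding size')
    parser.add_argument('--batch', default=4096, type=int, help='batch size')
    return parser.parse_args(argv)

# Parse defaults explicitly instead of reading sys.argv for this demo.
args = parse_args([])
```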