Full Code of HKUDS/AnyGraph for AI

Repository: HKUDS/AnyGraph
Branch: main
Commit: c2c5bfe103c4
Files: 15
Total size: 104.2 KB

Directory structure:
AnyGraph/

├── .gitignore
├── History/
│   ├── pretrain_link1.his
│   └── pretrain_link2.his
├── Models/
│   └── README.md
├── README.md
├── Utils/
│   └── TimeLogger.py
├── data_handler.py
├── main.py
├── model.py
├── node_classification/
│   ├── Utils/
│   │   └── TimeLogger.py
│   ├── data_handler.py
│   ├── main.py
│   ├── model.py
│   └── params.py
└── params.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
extract_code_structure.py
.DS_Store

================================================
FILE: Models/README.md
================================================
Download the pre-trained AnyGraph models from <a href='https://hkuhk-my.sharepoint.com/:u:/g/personal/lhaoxia_hku_hk/Efmm5TJm0B5EnmYzTqg8GWEB1loKzeIR5tcr3hPIOJDXXA?e=2wMgZC'>this link</a>.

================================================
FILE: README.md
================================================
<h1 align='center'>AnyGraph: Graph Foundation Model in the Wild</h1>

<div align='center'>
<a href='https://arxiv.org/pdf/2408.10700'><img src='https://img.shields.io/badge/Paper-PDF-green'></a>
<!-- <a href=''><img src='https://img.shields.io/badge/公众号-blue' /></a> -->
<!-- <a href=''><img src='https://img.shields.io/badge/CSDN-orange' /></a> -->
<img src="https://badges.pufler.dev/visits/hkuds/anygraph?style=flat-square&logo=github">
<img src='https://img.shields.io/github/stars/hkuds/anygraph?color=green&style=social' />

<a href='https://akaxlh.github.io/'>Lianghao Xia</a> and <a href='https://sites.google.com/view/chaoh/group-join-us'>Chao Huang</a>

**Introducing AnyGraph, a graph foundation model designed for zero-shot predictions across domains.**

<img src='imgs/article cover.png' />

</div>

**Objectives of AnyGraph:**

- **Structure Heterogeneity**: Addressing distribution shift in graph structural information.
- **Feature Heterogeneity**: Handling diverse feature representation spaces across graph datasets.
- **Fast Adaptation**: Efficiently adapting the model to new graph domains.
- **Scaling Law Emergence**: Performance scales with the amount of data and model parameters.

<br>

**Key Features of AnyGraph:**

- **Graph Mixture-of-Experts (MoE)**: Effectively addresses cross-domain heterogeneity using an array of expert models.
- **Lightweight Graph Expert Routing Mechanism**: Enables swift adaptation to new datasets and domains.
- **Adaptive and Efficient Graph Experts**: Custom-designed to handle graphs with a wide range of structural patterns and feature spaces.
- **Extensively Trained and Tested**: Exhibits strong generalizability over 38 diverse graph datasets, showcasing scaling laws and emergent capabilities.
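The expert-routing idea above can be sketched in a few lines. This is an illustrative toy, not the repository's implementation: the competence score here (how often an expert ranks observed edges above random node pairs) is a stand-in for the self-supervised routing score described in the paper, and all names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def competence(expert_emb, edges):
    # Fraction of observed edges that this expert's embeddings score
    # higher than random node pairs (higher = better fit).
    src, dst = edges[:, 0], edges[:, 1]
    pos = (expert_emb[src] * expert_emb[dst]).sum(-1)
    neg_dst = rng.integers(0, expert_emb.shape[0], size=len(src))
    neg = (expert_emb[src] * expert_emb[neg_dst]).sum(-1)
    return float((pos > neg).mean())

def route(experts, edges):
    # Assign the dataset to the expert with the highest competence
    # on a sampled edge set.
    scores = [competence(emb, edges) for emb in experts]
    return int(np.argmax(scores)), scores
```

Because the score only needs embedding lookups and dot products on a sampled edge set, routing stays cheap relative to full training, which is what enables fast adaptation to new datasets.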

<img src='imgs/framework_final.jpeg' />

## News
- [x] 2024/09/18 Fix minor bugs on loading and saving path, class names, etc.
- [x] 2024/09/11 Datasets are available now!
- [x] 2024/09/11 Pre-trained models are available on huggingface now!
- [x] 2024/08/20 Paper is released.
- [x] 2024/08/20 Code is released.

## Environment Setup
Download the data files from <a href='https://huggingface.co/datasets/hkuds/AnyGraph_datasets'>this link</a>. Then fill in your own data storage directories in the `get_data_files(self)` method of class `DataHandler` in `data_handler.py`.
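For reference, `get_data_files` looks for pickled sparse matrices under `<your data dir>/<dataset name>/`: `trn_mat.pkl`, `tst_mat.pkl`, `val_mat.pkl`, a few-shot `partial_mat_<shot>.pkl`, and an optional `feats.pkl`. A small helper for sanity-checking a dataset directory (not part of the repository, just a sketch):

```python
import os

# Files DataHandler.get_data_files always reads; feats.pkl is optional
# (datasets without node features fall back to adjacency-based projections),
# and the few-shot file name depends on args.ratio_fewshot.
REQUIRED = ['trn_mat.pkl', 'tst_mat.pkl', 'val_mat.pkl']

def check_dataset_dir(predir, data_name):
    """Return the list of required files missing from predir/data_name."""
    d = os.path.join(predir, data_name)
    return [f for f in REQUIRED if not os.path.exists(os.path.join(d, f))]
```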

Download the pre-trained AnyGraph models from <a href='https://huggingface.co/hkuds/AnyGraph/'>Hugging Face</a> or <a href='https://hkuhk-my.sharepoint.com/:u:/g/personal/lhaoxia_hku_hk/Efmm5TJm0B5EnmYzTqg8GWEB1loKzeIR5tcr3hPIOJDXXA?e=2wMgZC'>OneDrive</a>, and put them into `Models/`.

**Packages**: Our experiments were conducted with the following package versions:
* python==3.10.13
* torch==1.13.0
* numpy==1.23.4
* scipy==1.9.3

**Device Requirements**: Training and testing AnyGraph require only a single GPU with 24 GB of memory (e.g., RTX 3090 or 4090). Larger input graphs may require devices with more memory.

## Code Structure
Here is a brief overview of the code structure. The explanation for each file or directory follows the `##` marker.
```
./
├── .gitignore
├── README.md
├── data_handler.py ## load and process data
├── model.py ## implementation for the model
├── params.py ## hyperparameters
├── main.py ## main file for pretraining and link prediction
├── History/ ## training and testing logs
│   ├── pretrain_link1.his
│   └── pretrain_link2.his
├── Models/ ## pre-trained models
│   └── README.md
├── Utils/ ## utility functions
│   └── TimeLogger.py
├── imgs/ ## images used in readme
│   ├── ablation.png
│   ├── article cover.png
│   ├── datasets.png
│   ├── framework_final.jpeg
│   ├── framework_final.pdf
│   ├── framework_final.png
│   ├── overall_performance1.png
│   ├── overall_performance2.png
│   ├── routing.png
│   ├── scaling_law.png
│   ├── training_time.png
│   └── tuning_steps.png
└── node_classification/ ## test code for node classification
    ├── Utils/
    │   └── TimeLogger.py
    ├── data_handler.py
    ├── model.py
    ├── params.py
    └── main.py
```

## Usage
To reproduce the test performance reported in the paper, run the following command lines:
```
# Test on Link2 and Link1 data, respectively
python main.py --load pretrain_link1 --epoch 0 --dataset link2 
python main.py --load pretrain_link2 --epoch 0 --dataset link1

# Test on the Ecommerce datasets in the Link2 and Link1 group, respectively.
# Testing on the Academic and Others datasets is conducted similarly.
python main.py --load pretrain_link1 --epoch 0 --dataset ecommerce_in_link2
python main.py --load pretrain_link2 --epoch 0 --dataset ecommerce_in_link1

# Test the performance for node classification datasets
cd ./node_classification
python main.py --load pretrain_link2 --epoch 0 --dataset node
```

To re-train the two models by yourself, run:
```
python main.py --dataset link2+link1 --save pretrain_link2
python main.py --dataset link1+link2 --save pretrain_link1
```

## Datasets

<img src='imgs/datasets.png' />

The statistics for the experimental datasets are presented in the table above. We categorize them into distinct groups as below. Note that Link1 and Link2 include datasets from different sources, and the datasets do not share the same feature spaces. This separation ensures a robust evaluation of true zero-shot performance in graph prediction tasks.

| Group | Included Datasets |
| ----- | ----- |
| Link1 | Products-tech, Yelp2018, Yelp-textfeat, Products-home, Steam-text, Amazon-text, Amazon-book, Citation-2019, Citation-20Century, Pubmed-link, Citeseer, OGB-PPA, P2P-Gnutella06, Soc-Epinions1, Email-Enron |
| Link2 | Photo, Goodreads, Fitness, Movielens-1M, Movielens10M, Gowalla, Arxiv, Arxiv-t, Cora, CS, OGB-Collab, Proteins-0, Proteins-1, Proteins-2, Proteins-3, OGB-DDI, Web-Stanford, RoadNet-PA |
| Ecommerce | Products-tech, Yelp2018, Yelp-textfeat, Products-home, Steam-text, Amazon-text, Amazon-book, Photo, Goodreads, Fitness, Movielens-1M, Movielens10M, Gowalla |
| Academic | Citation-2019, Citation-20Century, Pubmed-link, Citeseer, OGB-PPA, Arxiv, Arxiv-t, Cora, CS, OGB-Collab |
| Others | P2P-Gnutella06, Soc-Epinions1, Email-Enron, Proteins-0, Proteins-1, Proteins-2, Proteins-3, OGB-DDI, Web-Stanford, RoadNet-PA |
| Node | Cora, Arxiv, Pubmed, Home, Tech |


## Experiments

### Model Pre-Training Curves
We present the training logs with respect to epochs below. Each figure contains two curves, corresponding to two repeated pre-training runs.

- pretrain_link1

<img src='imgs/link1_loss_curve.png' width=32%/>&nbsp;
<img src='imgs/link1_fullshot_ndcg_curve.png' width=32%/>&nbsp;
<img src='imgs/link1_zeroshot_ndcg_curve.png' width=32%/>

- pretrain_link2

<img src='imgs/link2_loss_curve.png' width=32%/>&nbsp;
<img src='imgs/link2_fullshot_ndcg_curve.png' width=32%/>&nbsp;
<img src='imgs/link2_zeroshot_ndcg_curve.png' width=32%/>

### Overall Performance Comparison

- Comparison with few-shot end-to-end models and pre-training plus fine-tuning methods.
![](imgs/overall_performance1.png)

- Comparison with zero-shot graph foundation models.
<img src='imgs/overall_performance2.png' width=60%/>

### Scaling Law of AnyGraph

We explore the scaling law of AnyGraph by evaluating 1) model performance vs. the number of model parameters, and 2) model performance vs. the number of training samples.

Below shows the evaluation results on 
- all datasets across domains (a)
- academic datasets (b)
- ecommerce datasets (c)
- other datasets (d) 

In each subfigure, we show 
- zero-shot performance on unseen datasets w.r.t. the amount of model parameters (left)
- full-shot performance on training datasets w.r.t. the amount of model parameters (middle)
- zero-shot performance w.r.t. the amount of training data (right)

![](imgs/scaling_law.png)

The results highlight the following key observations (see Sec. 4.3 for details):
- Generalizability of AnyGraph Follows the Scaling Law.
- Emergent Abilities of AnyGraph.
- Insufficient Training Data May Bring Bias.

### Ablation Study

The ablation study investigates the impact of the following modules:
- The overall MoE architecture
- Frequency regularization in the expert routing mechanism
- Graph augmentation in the learning process
- The utilization of (heterogeneous) node features from different datasets
  
<img src='imgs/ablation.png' width=60% />


### Expert Routing Mechanism
We visualize the competence scores between datasets and experts, given by the routing algorithm of AnyGraph. 

The scores below reveal the underlying relatedness between different datasets, demonstrating the effectiveness of the routing mechanism (see Sec. 4.5 for details).

<img src='imgs/routing.png' width=60% />


### Fast Adaptation of AnyGraph

We study the fast adaptation abilities of AnyGraph from two aspects:
- When fine-tuned on unseen datasets, AnyGraph achieves better performance with fewer training steps. (Fig. 6 below)
- The training time of AnyGraph is comparable to that of other methods. (Table 3 below)

<img src='imgs/tuning_steps.png' width=60% />

<img src='imgs/training_time.png' width=60% />

## Citation

If you find our work useful, please consider citing our paper:
```
@article{xia2024anygraph,
  title={AnyGraph: Graph Foundation Model in the Wild},
  author={Xia, Lianghao and Huang, Chao},
  journal={arXiv preprint arXiv:2408.10700},
  year={2024}
}
```


================================================
FILE: Utils/TimeLogger.py
================================================
import datetime

logmsg = ''
timemark = dict()
saveDefault = False
def log(msg, save=None, oneline=False):
	# Print msg with a timestamp; optionally buffer it in logmsg for
	# later dumping to a history file. oneline=True overwrites the
	# current console line (progress-bar style).
	global logmsg
	global saveDefault
	time = datetime.datetime.now()
	tem = '%s: %s' % (time, msg)
	if save is not None:
		if save:
			logmsg += tem + '\n'
	elif saveDefault:
		logmsg += tem + '\n'
	if oneline:
		print(tem, end='\r')
	else:
		print(tem)

def marktime(marker):
	global timemark
	timemark[marker] = datetime.datetime.now()


if __name__ == '__main__':
	log('')

================================================
FILE: data_handler.py
================================================
import pickle
import numpy as np
from scipy.sparse import csr_matrix, coo_matrix, dok_matrix
from params import args
import scipy.sparse as sp
from Utils.TimeLogger import log
import torch as t
import torch.utils.data as data
import torch_geometric.transforms as T
from model import Feat_Projector, Adj_Projector, TopoEncoder
import os

class MultiDataHandler:
    def __init__(self, trn_datasets, tst_datasets_group):
        all_datasets = trn_datasets
        all_tst_datasets = []
        for tst_datasets in tst_datasets_group:
            all_datasets = all_datasets + tst_datasets
            all_tst_datasets += tst_datasets
        all_datasets = list(set(all_datasets))
        all_datasets.sort()
        self.trn_handlers = []
        self.tst_handlers_group = [list() for i in range(len(tst_datasets_group))]
        for data_name in all_datasets:
            trn_flag = data_name in trn_datasets
            tst_flag = data_name in all_tst_datasets
            handler = DataHandler(data_name)
            if trn_flag:
                self.trn_handlers.append(handler)
            if tst_flag:
                for i in range(len(tst_datasets_group)):
                    if data_name in tst_datasets_group[i]:
                        self.tst_handlers_group[i].append(handler)
        self.make_joint_trn_loader()
    
    def make_joint_trn_loader(self):
        loader_datasets = []
        for trn_handler in self.trn_handlers:
            tem_dataset = trn_handler.trn_loader.dataset
            loader_datasets.append(tem_dataset)
        joint_dataset = JointTrnData(loader_datasets)
        self.joint_trn_loader = data.DataLoader(joint_dataset, batch_size=1, shuffle=True, num_workers=0)
    
    def remake_initial_projections(self):
        for i in range(len(self.trn_handlers)):
            trn_handler = self.trn_handlers[i]
            trn_handler.make_projectors()

class DataHandler:
    def __init__(self, data_name):
        self.data_name = data_name
        self.get_data_files()
        log(f'Loading dataset {data_name}')
        self.topo_encoder = TopoEncoder()
        self.load_data()
    
    def get_data_files(self):
        predir = f'/home/user_name/data/zero-shot datasets/{self.data_name}/'
        if os.path.exists(predir + 'feats.pkl'):
            self.feat_file = predir + 'feats.pkl'
        else:
            self.feat_file = None
        self.trnfile = predir + 'trn_mat.pkl'
        self.tstfile = predir + 'tst_mat.pkl'
        self.fewshotfile = predir + 'partial_mat_{shot}.pkl'.format(shot=args.ratio_fewshot)
        self.valfile = predir + 'val_mat.pkl'

    def load_one_file(self, filename):
        with open(filename, 'rb') as fs:
            ret = (pickle.load(fs) != 0).astype(np.float32)
        if type(ret) != coo_matrix:
            ret = sp.coo_matrix(ret)
        return ret
    
    def load_feats(self, filename):
        try:
            with open(filename, 'rb') as fs:
                feats = pickle.load(fs)
        except Exception as e:
            print(filename + str(e))
            exit()
        return feats

    def normalize_adj(self, mat, log=False):
        degree = np.array(mat.sum(axis=-1))
        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
        if mat.shape[0] == mat.shape[1]:
            return mat.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo()
        else:
            tem = d_inv_sqrt_mat.dot(mat)
            col_degree = np.array(mat.sum(axis=0))
            d_inv_sqrt = np.reshape(np.power(col_degree, -0.5), [-1])
            d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
            d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
            return tem.dot(d_inv_sqrt_mat).tocoo()
    
    def unique_numpy(self, row, col):
        hash_vals = row * args.node_num + col
        hash_vals = np.unique(hash_vals).astype(np.int64)
        col = hash_vals % args.node_num
        row = (hash_vals - col).astype(np.int64) // args.node_num
        return row, col

    def make_torch_adj(self, mat, unidirectional_for_asym=False):
        if mat.shape[0] == mat.shape[1]:
            _row = mat.row
            _col = mat.col
            row = np.concatenate([_row, _col]).astype(np.int64)
            col = np.concatenate([_col, _row]).astype(np.int64)
            row, col = self.unique_numpy(row, col)
            data = np.ones_like(row)
            mat = coo_matrix((data, (row, col)), mat.shape)
            if args.selfloop == 1:
                mat = (mat + sp.eye(mat.shape[0])) * 1.0
            normed_asym_mat = self.normalize_adj(mat)
            row = t.from_numpy(normed_asym_mat.row).long()
            col = t.from_numpy(normed_asym_mat.col).long()
            idxs = t.stack([row, col], dim=0)
            vals = t.from_numpy(normed_asym_mat.data).float()
            shape = t.Size(normed_asym_mat.shape)
            asym_adj = t.sparse.FloatTensor(idxs, vals, shape)
            return asym_adj
        elif unidirectional_for_asym:
            mat = (mat != 0) * 1.0
            mat = self.normalize_adj(mat, log=True)
            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
            vals = t.from_numpy(mat.data.astype(np.float32))
            shape = t.Size(mat.shape)
            return t.sparse.FloatTensor(idxs, vals, shape)
        else:
            # make ui adj
            a = sp.csr_matrix((args.user_num, args.user_num))
            b = sp.csr_matrix((args.item_num, args.item_num))
            mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
            mat = (mat != 0) * 1.0
            if args.selfloop == 1:
                mat = (mat + sp.eye(mat.shape[0])) * 1.0
            mat = self.normalize_adj(mat)

            # make cuda tensor
            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
            vals = t.from_numpy(mat.data.astype(np.float32))
            shape = t.Size(mat.shape)
            return t.sparse.FloatTensor(idxs, vals, shape)

    def load_data(self):
        tst_mat = self.load_one_file(self.tstfile)
        val_mat = self.load_one_file(self.valfile)
        trn_mat = self.load_one_file(self.trnfile)
        fewshot_mat = self.load_one_file(self.fewshotfile)
        if self.feat_file is not None:
            self.feats = t.from_numpy(self.load_feats(self.feat_file)).float()
            args.featdim = self.feats.shape[1]
        else:
            self.feats = None
            args.featdim = args.latdim

        if trn_mat.shape[0] != trn_mat.shape[1]:
            args.user_num, args.item_num = trn_mat.shape
            args.node_num = args.user_num + args.item_num
            print('Dataset: {data_name}, User num: {user_num}, Item num: {item_num}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, user_num=args.user_num, item_num=args.item_num, node_num=args.node_num, edge_num=trn_mat.nnz))
        else:
            args.node_num = trn_mat.shape[0]
            print('Dataset: {data_name}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, node_num=args.node_num, edge_num=trn_mat.nnz+val_mat.nnz+tst_mat.nnz))
        if args.tst_mode == 'tst':
            tst_data = TstData(tst_mat, trn_mat)
            self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
            self.tst_input_adj = self.make_torch_adj(trn_mat)
        elif args.tst_mode == 'val':
            tst_data = TstData(val_mat, trn_mat)
            self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
            self.tst_input_adj = self.make_torch_adj(fewshot_mat)
        else:
            raise Exception('Specify proper test mode')

        if args.trn_mode == 'fewshot':
            self.trn_mat = fewshot_mat
            trn_data = TrnData(self.trn_mat)
            self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0)
            if args.tst_mode == 'val':
                self.trn_input_adj = self.tst_input_adj
            else:
                self.trn_input_adj = self.make_torch_adj(fewshot_mat)
        elif args.trn_mode == 'train-all':
            self.trn_mat = trn_mat
            trn_data = TrnData(self.trn_mat)
            self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0)
            if args.tst_mode == 'tst':
                self.trn_input_adj = self.tst_input_adj
            else:
                self.trn_input_adj = self.make_torch_adj(trn_mat)
        else:
            raise Exception('Specify proper train mode')   

        if self.trn_mat.shape[0] == self.trn_mat.shape[1]:
            self.asym_adj = self.trn_input_adj
        else:
            self.asym_adj = self.make_torch_adj(self.trn_mat, unidirectional_for_asym=True)
        self.make_projectors()
        self.reproj_steps = max(len(self.trn_loader.dataset) // (10 * args.batch), args.proj_trn_steps)
        self.ratio_500_all = 500 / len(self.trn_loader)
    
    def make_projectors(self):
        with t.no_grad():
            projectors = []
            if args.proj_method == 'adj_svd' or args.proj_method == 'both':
                tem = self.asym_adj.to(args.devices[0])
                projectors = [Adj_Projector(tem)]
            if self.feats is not None and args.proj_method != 'adj_svd':
                tem = self.feats.to(args.devices[0])
                projectors.append(Feat_Projector(tem))
            assert (args.tst_mode == 'tst' and args.trn_mode == 'train-all') or (args.tst_mode == 'val' and args.trn_mode == 'fewshot')
            feats = projectors[0]()
            if len(projectors) == 2:
                feats2 = projectors[1]()
                feats = feats + feats2

            try:
                self.projectors = self.topo_encoder(self.trn_input_adj.to(args.devices[0]), feats.to(args.devices[0])).detach().cpu()
            except Exception:
                print(f'{self.data_name} memory overflow')
                mean, std = feats.mean(dim=-1, keepdim=True), feats.std(dim=-1, keepdim=True)
                tem_adj = self.trn_input_adj.to(args.devices[0])
                mem_cache = 256
                projectors_list = []
                for i in range(feats.shape[1] // mem_cache):
                    st, ed = i * mem_cache, (i + 1) * mem_cache
                    tem_feats = (feats[:, st:ed] - mean) / (std + 1e-8)
                    tem_feats = self.topo_encoder(tem_adj, tem_feats.to(args.devices[0]), normed=True).detach().cpu()
                    projectors_list.append(tem_feats)
                self.projectors = t.concat(projectors_list, dim=-1)
            t.cuda.empty_cache()

class TstData(data.Dataset):
    def __init__(self, coomat, trn_mat):
        self.csrmat = (trn_mat.tocsr() != 0) * 1.0
        tstLocs = [None] * coomat.shape[0]
        tst_nodes = set()
        for i in range(len(coomat.data)):
            row = coomat.row[i]
            col = coomat.col[i]
            if tstLocs[row] is None:
                tstLocs[row] = list()
            tstLocs[row].append(col)
            tst_nodes.add(row)
        tst_nodes = np.array(list(tst_nodes))
        self.tst_nodes = tst_nodes
        self.tstLocs = tstLocs

    def __len__(self):
        return len(self.tst_nodes)

    def __getitem__(self, idx):
        return self.tst_nodes[idx]

class TrnData(data.Dataset):
    def __init__(self, coomat):
        self.ancs, self.poss = coomat.row, coomat.col
        self.negs = np.zeros(len(self.ancs)).astype(np.int32)
        self.cand_num = coomat.shape[1]
        self.neg_shift = 0 if coomat.shape[0] == coomat.shape[1] else coomat.shape[0]
        self.poss = coomat.col + self.neg_shift
        self.neg_sampling()
    
    def neg_sampling(self):
        # sample negatives in the same (shifted) index space as the positives
        self.negs = self.neg_shift + np.random.randint(self.cand_num, size=self.poss.shape[0])

    def __len__(self):
        return len(self.ancs)
    
    def __getitem__(self, idx):
        return self.ancs[idx], self.poss[idx] , self.negs[idx]

class JointTrnData(data.Dataset):
    def __init__(self, dataset_list):
        self.batch_dataset_ids = []
        self.batch_st_ed_list = []
        self.dataset_list = dataset_list
        for dataset_id, dataset in enumerate(dataset_list):
            samp_num = len(dataset) // args.batch + (1 if len(dataset) % args.batch != 0 else 0)
            for j in range(samp_num):
                self.batch_dataset_ids.append(dataset_id)
                st = j * args.batch
                ed = min((j + 1) * args.batch, len(dataset))
                self.batch_st_ed_list.append((st, ed))
    
    def neg_sampling(self):
        for dataset in self.dataset_list:
            dataset.neg_sampling()

    def __len__(self):
        return len(self.batch_dataset_ids)
    
    def __getitem__(self, idx):
        st, ed = self.batch_st_ed_list[idx]
        dataset_id = self.batch_dataset_ids[idx]
        return *self.dataset_list[dataset_id][st: ed], dataset_id
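A note on the loader design above: `JointTrnData` returns one whole per-dataset batch from each `__getitem__`, so `MultiDataHandler` can run a `DataLoader` with `batch_size=1` while still shuffling entire batches across datasets. A self-contained sketch of this batch-of-batches indexing (toy sizes, hypothetical class name, no repo imports):

```python
import numpy as np

class BatchOfBatches:
    # Precompute (dataset_id, start, end) for every fixed-size chunk of
    # every dataset; each index then yields one complete batch.
    def __init__(self, dataset_sizes, batch):
        self.index = []
        for ds_id, n in enumerate(dataset_sizes):
            for st in range(0, n, batch):
                self.index.append((ds_id, st, min(st + batch, n)))

    def __len__(self):
        return len(self.index)

    def __getitem__(self, i):
        ds_id, st, ed = self.index[i]
        # A real implementation would slice the underlying dataset here,
        # as JointTrnData does; this sketch just returns the sample ids.
        return ds_id, np.arange(st, ed)
```

Shuffling indices of this dataset shuffles whole batches, which keeps every training step single-dataset (so one expert handles it) while interleaving datasets across steps.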

================================================
FILE: main.py
================================================
import torch as t
from torch import nn
import Utils.TimeLogger as logger
from Utils.TimeLogger import log
from params import args
from model import Expert, Feat_Projector, Adj_Projector, AnyGraph
from data_handler import MultiDataHandler, DataHandler
import numpy as np
import pickle
import os
import setproctitle
import time

class Exp:
    def __init__(self, multi_handler):
        self.multi_handler = multi_handler
        print(list(map(lambda x: x.data_name, multi_handler.trn_handlers)))
        for group_id, tst_handlers in enumerate(multi_handler.tst_handlers_group):
            print(f'Test group {group_id}', list(map(lambda x: x.data_name, tst_handlers)))
        self.metrics = dict()
        trn_mets = ['Loss', 'preLoss']
        tst_mets = ['Recall', 'NDCG', 'Loss', 'preLoss']
        mets = trn_mets + tst_mets
        for met in mets:
            if met in trn_mets:
                self.metrics['Train' + met] = list()
            if met in tst_mets:
                for i in range(len(self.multi_handler.tst_handlers_group)):
                    self.metrics['Test' + str(i) + met] = list()
        
    def make_print(self, name, ep, reses, save, data_name=None):
        if data_name is None:
            ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)
        else:
            ret = 'Epoch %d/%d, %s %s: ' % (ep, args.epoch, data_name, name)
        for metric in reses:
            val = reses[metric]
            ret += '%s = %.4f, ' % (metric, val)
            tem = name + metric if data_name is None else name + data_name + metric
            if save and tem in self.metrics:
                self.metrics[tem].append(val)
        ret = ret[:-2] + '      '
        return ret
    
    def run(self):
        self.prepare_model()
        log('Model Prepared')
        stloc = 0
        if args.load_model != None:
            self.load_model()
            stloc = len(self.metrics['TrainLoss']) * args.tst_epoch - (args.tst_epoch - 1)
        best_ndcg, best_ep = 0, -1
        for ep in range(stloc, args.epoch):
            tst_flag = (ep % args.tst_epoch == 0)
            start_time = time.time()
            self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=True)
            reses = self.train_epoch()
            log(self.make_print('Train', ep, reses, tst_flag))
            self.multi_handler.remake_initial_projections()
            end_time = time.time()
            print(f'NOTICE: {end_time-start_time}')
            if tst_flag:
                for handler_group_id in range(len(self.multi_handler.tst_handlers_group)):
                    tst_handlers = self.multi_handler.tst_handlers_group[handler_group_id]
                    self.model.assign_experts(tst_handlers, reca=False, log_assignment=True)
                    recall, ndcg, tstnum = 0, 0, 0
                    for i, handler in enumerate(tst_handlers):
                        reses = self.test_epoch(handler, i)
                        # log(self.make_print(f'{handler.data_name}', ep, reses, False))
                        recall += reses['Recall'] * reses['tstNum']
                        ndcg += reses['NDCG'] * reses['tstNum']
                        tstnum += reses['tstNum']
                    reses = {'Recall': recall / tstnum, 'NDCG': ndcg / tstnum}
                    log(self.make_print('Test'+str(handler_group_id), ep, reses, tst_flag))

                    if reses['NDCG'] > best_ndcg:
                        best_ndcg = reses['NDCG']
                        best_ep = ep
                self.save_history()
            print()

        for test_group_id in range(len(self.multi_handler.tst_handlers_group)):
            repeat_times = 5
            overall_recall, overall_ndcg = np.zeros(repeat_times), np.zeros(repeat_times)
            overall_tstnum = 0
            tst_handlers = self.multi_handler.tst_handlers_group[test_group_id]
            for i, handler in enumerate(tst_handlers):
                for topk in [args.topk]:
                    args.topk = topk
                    mets = dict()
                    for _ in range(repeat_times):
                        handler.make_projectors()
                        self.model.assign_experts([handler], reca=False, log_assignment=False)
                        reses = self.test_epoch(handler, 0)
                        for met in reses:
                            if met not in mets:
                                mets[met] = []
                            mets[met].append(reses[met])
                    tstnum = reses['tstNum']
                    tot_reses = dict()
                    for met in reses:
                        tem_arr = np.array(mets[met])
                        tot_reses[met + '_std'] = tem_arr.std()
                        tot_reses[met + '_mean'] = tem_arr.mean()
                    if topk == args.topk:
                        overall_recall += np.array(mets['Recall']) * tstnum
                        overall_ndcg += np.array(mets['NDCG']) * tstnum
                        overall_tstnum += tstnum
                    log(self.make_print(f'Test Top-{topk}', args.epoch, tot_reses, False, handler.data_name))
            overall_recall /= overall_tstnum
            overall_ndcg /= overall_tstnum
            overall_res = dict()
            overall_res['Recall_mean'] = overall_recall.mean()
            overall_res['Recall_std'] = overall_recall.std()
            overall_res['NDCG_mean'] = overall_ndcg.mean()
            overall_res['NDCG_std'] = overall_ndcg.std()
            log(self.make_print('Overall Test', args.epoch, overall_res, False))
        self.save_history()

    def print_model_size(self):
        total_params = 0
        trainable_params = 0
        non_trainable_params = 0
        for param in self.model.parameters():
            tem = np.prod(param.size())
            total_params += tem
            if param.requires_grad:
                trainable_params += tem
            else:
                non_trainable_params += tem
        print(f'Total params: {total_params/1e6}')
        print(f'Trainable params: {trainable_params/1e6}')
        print(f'Non-trainable params: {non_trainable_params/1e6}')

    def prepare_model(self):
        self.model = AnyGraph()
        t.cuda.empty_cache()
        self.print_model_size()

    def train_epoch(self):
        self.model.train()
        trn_loader = self.multi_handler.joint_trn_loader
        trn_loader.dataset.neg_sampling()
        ep_loss, ep_preloss, ep_regloss = 0, 0, 0
        steps = len(trn_loader)
        tot_samp_num = 0
        counter = [0] * len(self.multi_handler.trn_handlers)
        reassign_steps = sum(list(map(lambda x: x.reproj_steps, self.multi_handler.trn_handlers)))
        for i, batch_data in enumerate(trn_loader):
            if args.epoch_max_step > 0 and i >= args.epoch_max_step:
                break
            ancs, poss, negs, dataset_id = batch_data
            ancs = ancs[0].long()
            poss = poss[0].long()
            negs = negs[0].long()
            dataset_id = dataset_id[0].long()
            tem_bar = self.multi_handler.trn_handlers[dataset_id].ratio_500_all
            if tem_bar < 1.0 and np.random.uniform() > tem_bar:
                steps -= 1
                continue

            expert = self.model.summon(dataset_id)
            opt = self.model.summon_opt(dataset_id)
            feats = self.multi_handler.trn_handlers[dataset_id].projectors
            loss, loss_dict = expert.cal_loss((ancs, poss, negs), feats)
            opt.zero_grad()
            loss.backward()
            opt.step()

            sample_num = ancs.shape[0]
            tot_samp_num += sample_num
            ep_loss += loss.item() * sample_num
            ep_preloss += loss_dict['preloss'].item() * sample_num
            ep_regloss += loss_dict['regloss'].item()
            log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f        ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True)

            counter[dataset_id] += 1
            if (counter[dataset_id] + 1) % self.multi_handler.trn_handlers[dataset_id].reproj_steps == 0:
                self.multi_handler.trn_handlers[dataset_id].make_projectors()
            if (i + 1) % reassign_steps == 0:
                self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=False)
        ret = dict()
        ret['Loss'] = ep_loss / tot_samp_num
        ret['preLoss'] = ep_preloss / tot_samp_num
        ret['regLoss'] = ep_regloss / steps
        t.cuda.empty_cache()
        return ret
    
    def make_trn_masks(self, numpy_usrs, csr_mat):
        trn_masks = csr_mat[numpy_usrs].tocoo()
        cand_size = trn_masks.shape[1]
        trn_masks = t.from_numpy(np.stack([trn_masks.row, trn_masks.col], axis=0)).long()
        return trn_masks, cand_size

    def test_epoch(self, handler, dataset_id):
        with t.no_grad():
            tst_loader = handler.tst_loader
            self.model.eval()
            expert = self.model.summon(dataset_id)
            ep_recall, ep_ndcg = 0, 0
            ep_tstnum = len(tst_loader.dataset)
            steps = max(ep_tstnum // args.tst_batch, 1)
            for i, batch_data in enumerate(tst_loader):
                if args.tst_steps != -1 and i > args.tst_steps:
                    break

                usrs = batch_data.long()
                trn_masks, cand_size = self.make_trn_masks(batch_data.numpy(), tst_loader.dataset.csrmat)
                feats = handler.projectors
                all_preds = expert.pred_for_test((usrs, trn_masks), cand_size, feats, rerun_embed=False if i!=0 else True)
                _, top_locs = t.topk(all_preds, args.topk)
                top_locs = top_locs.cpu().numpy()
                recall, ndcg = self.calc_recall_ndcg(top_locs, tst_loader.dataset.tstLocs, usrs)
                ep_recall += recall
                ep_ndcg += ndcg
                log('Steps %d/%d: recall = %.2f, ndcg = %.2f          ' % (i, steps, recall, ndcg), save=False, oneline=True)
        ret = dict()
        if args.tst_steps != -1:
            ep_tstnum = args.tst_steps * args.tst_batch
        ret['Recall'] = ep_recall / ep_tstnum
        ret['NDCG'] = ep_ndcg / ep_tstnum
        ret['tstNum'] = ep_tstnum
        t.cuda.empty_cache()
        return ret
    
    def calc_recall_ndcg(self, topLocs, tstLocs, batIds):
        assert topLocs.shape[0] == len(batIds)
        allRecall = allNdcg = 0
        for i in range(len(batIds)):
            temTopLocs = list(topLocs[i])
            temTstLocs = tstLocs[batIds[i]]
            tstNum = len(temTstLocs)
            maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))])
            recall = dcg = 0
            for val in temTstLocs:
                if val in temTopLocs:
                    recall += 1
                    dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2))
            recall = recall / tstNum
            ndcg = dcg / maxDcg
            allRecall += recall
            allNdcg += ndcg
        return allRecall, allNdcg
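The metric logic in `calc_recall_ndcg` can be checked on a toy example. This standalone sketch (a hypothetical helper, pure NumPy, single user) mirrors the per-user Recall@k / NDCG@k computation above:

```python
import numpy as np

def recall_ndcg_at_k(top_locs, tst_items, k):
    """Recall@k and NDCG@k for one user, mirroring calc_recall_ndcg."""
    top_locs = list(top_locs)[:k]
    # ideal DCG: all test items ranked at the top
    max_dcg = sum(1.0 / np.log2(rank + 2) for rank in range(min(len(tst_items), k)))
    hits = [item for item in tst_items if item in top_locs]
    recall = len(hits) / len(tst_items)
    dcg = sum(1.0 / np.log2(top_locs.index(item) + 2) for item in hits)
    return recall, dcg / max_dcg
```

For example, with predictions `[5, 3, 9]` and test items `[3, 7]`, one of two items is retrieved at rank 1 (0-indexed), giving Recall@3 = 0.5.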
    
    def save_history(self):
        if args.epoch == 0:
            return
        with open('./History/' + args.save_path + '.his', 'wb') as fs:
            pickle.dump(self.metrics, fs)

        content = {
            'model': self.model,
        }
        t.save(content, './Models/' + args.save_path + '.mod')
        log('Model Saved: %s' % args.save_path)

    def load_model(self):
        ckp = t.load('./Models/' + args.load_model + '.mod')
        self.model = ckp['model']
        self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)

        with open('./History/' + args.load_model + '.his', 'rb') as fs:
            self.metrics = pickle.load(fs)
        log('Model Loaded')

if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if len(args.gpu.split(',')) == 2:
        args.devices = ['cuda:0', 'cuda:1']
    elif len(args.gpu.split(',')) > 2:
        raise Exception('At most 2 GPU devices are supported')
    else:
        args.devices = ['cuda:0', 'cuda:0']
    logger.saveDefault = True
    setproctitle.setproctitle('AnyGraph')

    log('Start')

    
    datasets = dict()
    datasets['all'] = [
        'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech', 'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab', 'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1'
    ]
    datasets['ecommerce'] = [
        'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech'
    ]
    datasets['academic'] = [
        'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab'
    ]
    datasets['others'] = [
        'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1'
    ]
    datasets['link1'] = [
        'products_tech', 'yelp2018', 'yelp_textfeat', 'products_home', 'steam_textfeat', 'amazon_textfeat', 'amazon-book', 'citation-2019', 'citation-classic', 'pubmed', 'citeseer', 'ppa', 'p2p-Gnutella06', 'soc-Epinions1', 'email-Enron',
    ]
    datasets['link2'] = [
        'Photo', 'Goodreads', 'Fitness', 'ml1m', 'ml10m', 'gowalla', 'arxiv', 'arxiv-ta', 'cora', 'CS', 'collab', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'ddi', 'web-Stanford', 'roadNet-PA',
    ]

    if args.dataset_setting in datasets.keys():
        trn_datasets = tst_datasets = datasets[args.dataset_setting]
    elif args.dataset_setting in datasets['all']:
        trn_datasets = tst_datasets = [args.dataset_setting]
    elif '+' in args.dataset_setting:
        idx = args.dataset_setting.index('+')
        trn_datasets = datasets[args.dataset_setting[:idx]]
        tst_datasets = datasets[args.dataset_setting[idx+1:]]
    elif '_in_' in args.dataset_setting:
        idx = args.dataset_setting.index('_in_')
        tst_datasets_1 = datasets[args.dataset_setting[:idx]]
        tst_datasets_2 = datasets[args.dataset_setting[idx+len('_in_'):]]
        tst_datasets = []
        for data in tst_datasets_1:
            if data in tst_datasets_2:
                tst_datasets.append(data)
        trn_datasets = tst_datasets

    if '+' not in args.dataset_setting:
        # No zero-shot prediction test
        handler = MultiDataHandler(trn_datasets, [tst_datasets])
    else:
        handler = MultiDataHandler(trn_datasets, [trn_datasets, tst_datasets])
    log('Load Data')

    exp = Exp(handler)
    exp.run()
    print(args.load_model, args.dataset_setting)
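The `dataset_setting` string above is resolved in four ways: a group name, a single dataset name, `trn+tst` (train on one group, zero-shot test on another), or `a_in_b` (intersection of two groups). A minimal standalone sketch of that resolution logic (`resolve_datasets` is a hypothetical helper; the toy group names are not real dataset groups):

```python
def resolve_datasets(setting, groups):
    """Mirror main.py's dataset_setting parsing: returns (trn, tst) lists."""
    if setting in groups:                      # a known group name
        return groups[setting], groups[setting]
    if setting in groups['all']:               # a single dataset
        return [setting], [setting]
    if '+' in setting:                         # 'trn_group+tst_group'
        trn, tst = setting.split('+', 1)
        return groups[trn], groups[tst]
    if '_in_' in setting:                      # intersection of two groups
        a, b = setting.split('_in_', 1)
        tst = [d for d in groups[a] if d in groups[b]]
        return tst, tst
    raise ValueError(f'Unknown dataset setting: {setting}')
```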


================================================
FILE: model.py
================================================
import torch as t
from torch import nn
import torch.nn.functional as F
from params import args
import numpy as np
from Utils.TimeLogger import log
from torch.nn import MultiheadAttention
from time import time

init = nn.init.xavier_uniform_
uniformInit = nn.init.uniform_

class FeedForwardLayer(nn.Module):
    def __init__(self, in_feat, out_feat, bias=True, act=None):
        super(FeedForwardLayer, self).__init__()
        self.linear = nn.Linear(in_feat, out_feat, bias=bias)#, dtype=t.bfloat16)
        if act == 'identity' or act is None:
            self.act = None
        elif act == 'leaky':
            self.act = nn.LeakyReLU(negative_slope=args.leaky)
        elif act == 'relu':
            self.act = nn.ReLU()
        elif act == 'relu6':
            self.act = nn.ReLU6()
        else:
            raise Exception(f'Unsupported activation: {act}')
    
    def forward(self, embeds):
        if self.act is None:
            return self.linear(embeds)
        return self.act(self.linear(embeds))

class TopoEncoder(nn.Module):
    def __init__(self):
        super(TopoEncoder, self).__init__()

        self.layer_norm = nn.LayerNorm(args.latdim, elementwise_affine=False)

    def forward(self, adj, embeds, normed=False):
        with t.no_grad():
            if not normed:
                embeds = self.layer_norm(embeds)
            # embeds_list = []
            final_embeds = 0
            if args.gnn_layer == 0:
                final_embeds = embeds
                # embeds_list.append(embeds)
            for _ in range(args.gnn_layer):
                embeds = t.spmm(adj, embeds)
                final_embeds += embeds
                # embeds_list.append(embeds)
            embeds = final_embeds
        return embeds
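TopoEncoder is parameter-free: it sums the input embeddings propagated over `gnn_layer` hops of the normalized adjacency. A dense NumPy sketch of that accumulation (hypothetical helper `topo_encode`, standing in for the sparse `t.spmm` loop above):

```python
import numpy as np

def topo_encode(adj, embeds, num_layers):
    """Sum of adj^1 @ X ... adj^L @ X; with L == 0, the raw embeddings
    pass through, mirroring TopoEncoder.forward (layer norm omitted)."""
    if num_layers == 0:
        return embeds.copy()
    out = np.zeros_like(embeds)
    cur = embeds
    for _ in range(num_layers):
        cur = adj @ cur       # one propagation hop
        out = out + cur       # accumulate every hop
    return out
```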

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(args.fc_layer)])
        self.layer_norms = nn.Sequential(*[nn.LayerNorm(args.latdim, elementwise_affine=True) for _ in range(args.fc_layer)])
        self.dropout = nn.Dropout(p=args.drop_rate)
    
    def forward(self, embeds):
        for i in range(args.fc_layer):
            embeds = self.layer_norms[i](self.dropout(self.dense_layers[i](embeds)) + embeds)
        return embeds

class GTLayer(nn.Module):
    def __init__(self):
        super(GTLayer, self).__init__()
        self.multi_head_attention = MultiheadAttention(args.latdim, args.head, dropout=0.1, bias=False)#, dtype=t.bfloat16)
        self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(2)])# bias=False
        self.layer_norm1 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
        self.layer_norm2 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
        self.fc_dropout = nn.Dropout(p=args.drop_rate)
    
    def _pick_anchors(self, embeds):
        perm = t.randperm(embeds.shape[0])
        anchors = perm[:args.anchor]
        return embeds[anchors]
    
    def forward(self, embeds):
        anchor_embeds = self._pick_anchors(embeds)
        _anchor_embeds, _ = self.multi_head_attention(anchor_embeds, embeds, embeds)
        anchor_embeds = _anchor_embeds + anchor_embeds
        _embeds, _ = self.multi_head_attention(embeds, anchor_embeds, anchor_embeds, need_weights=False)
        embeds = self.layer_norm1(_embeds + embeds)
        _embeds = self.fc_dropout(self.dense_layers(embeds))
        embeds = (self.layer_norm2(_embeds + embeds))
        return embeds
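GTLayer avoids quadratic all-pairs attention by routing through `args.anchor` randomly sampled anchor nodes: anchors first attend to all nodes, then all nodes attend back to the updated anchors, for roughly O(N·k·d) cost. A single-head NumPy sketch of just that two-step exchange (hypothetical helper; the multi-head projections, dropout, layer norms, and FFN above are omitted):

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # shift for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def anchor_attention(embeds, num_anchors, rng):
    """Two-step anchor attention: anchors aggregate from all nodes,
    then every node reads from the updated anchors."""
    d = embeds.shape[1]
    anchors = embeds[rng.permutation(len(embeds))[:num_anchors]]
    anchors = softmax(anchors @ embeds.T / np.sqrt(d)) @ embeds + anchors
    return softmax(embeds @ anchors.T / np.sqrt(d)) @ anchors
```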

class GraphTransformer(nn.Module):
    def __init__(self):
        super(GraphTransformer, self).__init__()
        self.gt_layers = nn.Sequential(*[GTLayer() for i in range(args.gt_layer)])

    def forward(self, embeds):
        for i, layer in enumerate(self.gt_layers):
            embeds = layer(embeds) / args.scale_layer
        return embeds

class Feat_Projector(nn.Module):
    def __init__(self, feats):
        super(Feat_Projector, self).__init__()

        if args.proj_method == 'uniform':
            self.proj_embeds = self.uniform_proj(feats)
        elif args.proj_method == 'svd' or args.proj_method == 'both':
            self.proj_embeds = self.svd_proj(feats)
        elif args.proj_method == 'random':
            self.proj_embeds = self.random_proj(feats)
        elif args.proj_method == 'original':
            self.proj_embeds = feats
        self.proj_embeds = t.flip(self.proj_embeds, dims=[-1])
        self.proj_embeds = self.proj_embeds.detach()
    
    def svd_proj(self, feats):
        if args.latdim > feats.shape[0] or args.latdim > feats.shape[1]:
            dim = min(feats.shape[0], feats.shape[1])
            decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=dim, niter=args.niter)
            decom_feats = t.concat([decom_feats, t.zeros([decom_feats.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
            s = t.concat([s, t.zeros(args.latdim - dim).to(args.devices[0])])
        else:
            decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=args.latdim, niter=args.niter)
        decom_feats = decom_feats @ t.diag(t.sqrt(s))
        return decom_feats.cpu()
    
    def uniform_proj(self, feats):
        projection = init(t.empty(args.featdim, args.latdim))
        return feats @ projection
    
    def random_proj(self, feats):
        projection = init(t.empty(feats.shape[0], args.latdim))
        return projection
    
    def forward(self):
        return self.proj_embeds

class Adj_Projector(nn.Module):
    def __init__(self, adj):
        super(Adj_Projector, self).__init__()

        if args.proj_method == 'adj_svd' or args.proj_method == 'both':
            self.proj_embeds = self.svd_proj(adj)
        self.proj_embeds = self.proj_embeds.detach()
    
    def svd_proj(self, adj):
        q = args.latdim
        if args.latdim > adj.shape[0] or args.latdim > adj.shape[1]:
            dim = min(adj.shape[0], adj.shape[1])
            svd_u, s, svd_v = t.svd_lowrank(adj, q=dim, niter=args.niter)
            svd_u = t.concat([svd_u, t.zeros([svd_u.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
            svd_v = t.concat([svd_v, t.zeros([svd_v.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
            s = t.concat([s, t.zeros(args.latdim-dim).to(args.devices[0])])
        else:
            svd_u, s, svd_v = t.svd_lowrank(adj, q=q, niter=args.niter)
        svd_u = svd_u @ t.diag(t.sqrt(s))
        svd_v = svd_v @ t.diag(t.sqrt(s))
        if adj.shape[0] != adj.shape[1]:
            projection = t.concat([svd_u, svd_v], dim=0)
        else:
            projection = svd_u + svd_v
        return projection.cpu()
    
    def forward(self):
        return self.proj_embeds
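The SVD projection above embeds each node by its singular vectors scaled with the square root of the singular values; U and V blocks are concatenated for a bipartite adjacency and summed for a square one. A NumPy sketch (hypothetical helper `svd_embed`, using exact `np.linalg.svd` in place of torch's randomized `svd_lowrank`, and without the zero-padding branch for small graphs):

```python
import numpy as np

def svd_embed(adj, dim):
    """Low-rank SVD embedding as in Adj_Projector.svd_proj:
    u * sqrt(s) and v * sqrt(s), so that u_emb @ v_emb.T ~= adj."""
    u, s, vt = np.linalg.svd(adj, full_matrices=False)
    u, s, v = u[:, :dim], s[:dim], vt[:dim].T
    u = u * np.sqrt(s)
    v = v * np.sqrt(s)
    # rows and columns index disjoint node sets in a bipartite graph
    return np.concatenate([u, v], axis=0) if adj.shape[0] != adj.shape[1] else u + v
```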

class Expert(nn.Module):
    def __init__(self):
        super(Expert, self).__init__()
        
        self.topo_encoder = TopoEncoder().to(args.devices[0])
        if args.nn == 'mlp':
            self.trainable_nn = MLP().to(args.devices[1])
        else:
            self.trainable_nn = GraphTransformer().to(args.devices[1])
        self.trn_count = 1
    
    def forward(self, projectors, pck_nodes=None):
        embeds = projectors.to(args.devices[1])
        if pck_nodes is not None:
            embeds = embeds[pck_nodes]
        embeds = self.trainable_nn(embeds)
        return embeds

    def pred_norm(self, pos_preds, neg_preds):
        pos_preds_num = pos_preds.shape[0]
        neg_preds_shape = neg_preds.shape
        preds = t.concat([pos_preds, neg_preds.view(-1)])
        preds = preds - preds.max()
        pos_preds = preds[:pos_preds_num]
        neg_preds = preds[pos_preds_num:].view(neg_preds_shape)
        return pos_preds, neg_preds
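`pred_norm` subtracts the joint maximum from positive and negative logits so the exponentials in the cross-entropy branch of `cal_loss` cannot overflow. A NumPy sketch of the combined shift-then-loss computation (hypothetical helper `stable_ce_loss`; the small epsilon used above is dropped since the denominator is already at least 1 after the shift):

```python
import numpy as np

def stable_ce_loss(pos_preds, neg_preds):
    """-(pos - log(sum(exp(neg)) + exp(pos))), with the global max
    subtracted first, as pred_norm does before cal_loss exponentiates."""
    shift = max(pos_preds.max(), neg_preds.max())
    pos = pos_preds - shift
    neg = neg_preds - shift
    neg_term = np.log(np.exp(neg).sum(axis=-1) + np.exp(pos))
    return -(pos - neg_term).mean()
```

The shift leaves the loss unchanged (softmax is shift-invariant) while keeping every `exp()` argument at most 0.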
    
    def cal_loss(self, batch_data, projectors):
        ancs, poss, negs = list(map(lambda x: x.to(args.devices[1]), batch_data))
        self.trn_count += ancs.shape[0]
        pck_nodes = t.concat([ancs, poss, negs])
        final_embeds = self.forward(projectors, pck_nodes)
        # anc_embeds, pos_embeds, neg_embeds = final_embeds[ancs], final_embeds[poss], final_embeds[negs]
        anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds, [ancs.shape[0]] * 3)
        if final_embeds.isinf().any() or final_embeds.isnan().any():
            raise Exception('Final embedding fails')
        
        if args.loss == 'ce':
            pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T)
            if pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any():
                raise Exception('Preds fails')
            pos_loss = pos_preds
            neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log()
            pre_loss = -(pos_loss - neg_loss).mean()
        elif args.loss == 'bpr':
            pos_preds = (anc_embeds * pos_embeds).sum(-1)
            neg_preds = (anc_embeds * neg_embeds).sum(-1)
            pos_loss, neg_loss = pos_preds, neg_preds
            pre_loss = -((pos_preds - neg_preds).sigmoid() + 1e-10).log().mean() 

        if t.isinf(pre_loss).any() or t.isnan(pre_loss).any():
            raise Exception('NaN or Inf')

        reg_loss = sum(list(map(lambda W: W.norm(2).square() * args.reg, self.parameters())))
        loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean()}
        return pre_loss + reg_loss, loss_dict
    
    def pred_for_test(self, batch_data, cand_size, projectors, rerun_embed=True):
        ancs, trn_mask = list(map(lambda x: x.to(args.devices[1]), batch_data))
        if rerun_embed:
            try:
                final_embeds = self.forward(projectors)
            except Exception:
                final_embeds_list = []
                div = args.batch * 3
                temlen = projectors.shape[0] // div
                for i in range(temlen):
                    st, ed = div * i, div * (i + 1)
                    tem_projectors = projectors[st: ed, :]
                    final_embeds_list.append(self.forward(tem_projectors))
                if temlen * div < projectors.shape[0]:
                    tem_projectors = projectors[temlen*div:, :]
                    final_embeds_list.append(self.forward(tem_projectors))
                final_embeds = t.concat(final_embeds_list, dim=0)
            self.final_embeds = final_embeds
        final_embeds = self.final_embeds
        anc_embeds = final_embeds[ancs]
        cand_embeds = final_embeds[-cand_size:]

        mask_mat = t.sparse.FloatTensor(trn_mask, t.ones(trn_mask.shape[1]).to(args.devices[1]), t.Size([ancs.shape[0], cand_size]))
        dense_mat = mask_mat.to_dense()
        all_preds = anc_embeds @ cand_embeds.T * (1 - dense_mat) - dense_mat * 1e8
        return all_preds

    def attempt(self, topo_embeds, dataset):
        final_embeds = self.trainable_nn(topo_embeds)
        rows, cols, negs = list(map(lambda x: t.from_numpy(x).long().to(args.devices[1]), [dataset.ancs, dataset.poss, dataset.negs]))
        if rows.shape[0] > args.attempt_cache:
            random_perm = t.randperm(rows.shape[0], device=args.devices[0])
            pck_perm = random_perm[:args.attempt_cache]
            rows = rows[pck_perm]
            cols = cols[pck_perm]
            negs = negs[pck_perm]
        while True:
            try:
                row_embeds = final_embeds[rows]
                col_embeds = final_embeds[cols]
                neg_embeds = final_embeds[negs]
                score = ((row_embeds * col_embeds).sum(-1) - (row_embeds * neg_embeds).sum(-1)).sigmoid().mean().item()
                break
            except Exception:
                args.attempt_cache = args.attempt_cache // 2
                random_perm = t.randperm(rows.shape[0], device=args.devices[0])
                pck_perm = random_perm[:args.attempt_cache]
                rows = rows[pck_perm]
                cols = cols[pck_perm]
                negs = negs[pck_perm]
        t.cuda.empty_cache()
        return score

class AnyGraph(nn.Module):
    def __init__(self):
        super(AnyGraph, self).__init__()
        self.experts = nn.ModuleList([Expert() for _ in range(args.expert_num)]).cuda()
        self.opts = list(map(lambda expert: t.optim.Adam(expert.parameters(), lr=args.lr, weight_decay=0), self.experts))
        
    def assign_experts(self, handlers, reca=True, log_assignment=False):
        if args.expert_num == 1:
            self.assignment = [0] * len(handlers)
            return
        try:
            expert_scores = np.array(list(map(lambda expert: expert.trn_count, self.experts)))
            expert_scores = (1.0 - expert_scores / np.sum(expert_scores)) * args.reca_range + 1.0 - args.reca_range / 2
        except Exception:
            expert_scores = np.ones(len(self.experts))
        with t.no_grad():
            assignment = [list() for i in range(len(handlers))]
            for dataset_id, handler in enumerate(handlers):
                topo_embeds = handler.projectors.to(args.devices[1])
                for expert_id, expert in enumerate(self.experts):
                    expert = expert.to(args.devices[1])
                    score = expert.attempt(topo_embeds, handler.trn_loader.dataset)
                    if reca:
                        score *= expert_scores[expert_id]
                    assignment[dataset_id].append((expert_id, score))
                assignment[dataset_id].sort(key=lambda x: x[1], reverse=True)
            if log_assignment:
                print('\n----------\nAssignment')
                for dataset_id, handler in enumerate(handlers):
                    out = ''
                    for exp_idx in range(min(4, len(self.experts))):
                        out += f'({assignment[dataset_id][exp_idx][0]}, {assignment[dataset_id][exp_idx][1]}) '
                    print(handler.data_name, out)
                print('----------\n')

            self.assignment = list(map(lambda x: x[0][0], assignment))
    
    def summon(self, dataset_id):
        return self.experts[self.assignment[dataset_id]]
    
    def summon_opt(self, dataset_id):
        return self.opts[self.assignment[dataset_id]]
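`assign_experts` routes each dataset to the expert with the best BPR-style score, rescaled by a load-balancing factor that favors experts with fewer training samples (`reca`). A standalone sketch of that routing rule (hypothetical helper `route_datasets`; the real scores come from `Expert.attempt`):

```python
import numpy as np

def route_datasets(scores, trn_counts, reca_range=0.1):
    """Pick one expert per dataset: raw score times a recalibration
    factor in [1 - reca_range/2, 1 + reca_range/2] that shrinks with
    an expert's share of past training samples."""
    counts = np.asarray(trn_counts, dtype=float)
    balance = (1.0 - counts / counts.sum()) * reca_range + 1.0 - reca_range / 2
    return (np.asarray(scores, dtype=float) * balance).argmax(axis=1)
```

With tied raw scores, the less-trained expert wins; with equal loads, the raw score decides.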

================================================
FILE: node_classification/Utils/TimeLogger.py
================================================
import datetime

logmsg = ''
timemark = dict()
saveDefault = False
def log(msg, save=None, oneline=False):
	global logmsg
	global saveDefault
	time = datetime.datetime.now()
	tem = '%s: %s' % (time, msg)
	if save != None:
		if save:
			logmsg += tem + '\n'
	elif saveDefault:
		logmsg += tem + '\n'
	if oneline:
		print(tem, end='\r')
	else:
		print(tem)

def marktime(marker):
	global timemark
	timemark[marker] = datetime.datetime.now()


if __name__ == '__main__':
	log('')

================================================
FILE: node_classification/data_handler.py
================================================
import pickle
import numpy as np
from scipy.sparse import csr_matrix, coo_matrix, dok_matrix
from params import args
import scipy.sparse as sp
from Utils.TimeLogger import log
import torch as t
import torch.utils.data as data
import torch_geometric.transforms as T
from model import Feat_Projector, Adj_Projector, TopoEncoder
import os

class MultiDataHandler:
    def __init__(self, trn_datasets, tst_datasets_group):
        all_datasets = trn_datasets
        all_tst_datasets = []
        for tst_datasets in tst_datasets_group:
            all_datasets = all_datasets + tst_datasets
            all_tst_datasets += tst_datasets
        all_datasets = list(set(all_datasets))
        all_datasets.sort()
        self.trn_handlers = []
        self.tst_handlers_group = [list() for i in range(len(tst_datasets_group))]
        for data_name in all_datasets:
            trn_flag = data_name in trn_datasets
            tst_flag = data_name in all_tst_datasets
            handler = DataHandler(data_name)
            if trn_flag:
                self.trn_handlers.append(handler)
            if tst_flag:
                for i in range(len(tst_datasets_group)):
                    if data_name in tst_datasets_group[i]:
                        self.tst_handlers_group[i].append(handler)
        self.make_joint_trn_loader()
    
    def make_joint_trn_loader(self):
        loader_datasets = []
        for trn_handler in self.trn_handlers:
            tem_dataset = trn_handler.trn_loader.dataset
            loader_datasets.append(tem_dataset)
        joint_dataset = JointTrnData(loader_datasets)
        self.joint_trn_loader = data.DataLoader(joint_dataset, batch_size=1, shuffle=True, num_workers=0)
    
    def remake_initial_projections(self):
        for i in range(len(self.trn_handlers)):
            trn_handler = self.trn_handlers[i]
            trn_handler.make_projectors()

class DataHandler:
    def __init__(self, data_name):
        self.data_name = data_name
        self.get_data_files()
        log(f'Loading dataset {data_name}')
        self.topo_encoder = TopoEncoder()
        self.load_data()
    
    def get_data_files(self):
        predir = f'/home/user_name/data/zero-shot datasets/node_data/{self.data_name}/'
        # predir = f'../handle_node_data/{self.data_name}/'
        if os.path.exists(predir + 'feats.pkl'):
            self.feat_file = predir + 'feats.pkl'
        else:
            self.feat_file = None
        self.trnfile = predir + 'trn_mat.pkl'
        self.tstfile = predir + 'tst_mat.pkl'
        self.fewshotfile = predir + 'fewshot_mat_{shot}.pkl'.format(shot=args.shot)
        self.valfile = predir + 'val_mat.pkl'

    def load_one_file(self, filename):
        with open(filename, 'rb') as fs:
            ret = (pickle.load(fs) != 0).astype(np.float32)
        if not isinstance(ret, coo_matrix):
            ret = sp.coo_matrix(ret)
        return ret
    
    def load_feats(self, filename):
        try:
            with open(filename, 'rb') as fs:
                feats = pickle.load(fs)
        except Exception as e:
            print(filename + str(e))
            exit()
        return feats

    def normalize_adj(self, mat, log=False):
        degree = np.array(mat.sum(axis=-1))
        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
        if mat.shape[0] == mat.shape[1]:
            return mat.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo()
        else:
            tem = d_inv_sqrt_mat.dot(mat)
            col_degree = np.array(mat.sum(axis=0))
            d_inv_sqrt = np.reshape(np.power(col_degree, -0.5), [-1])
            d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
            d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
            return tem.dot(d_inv_sqrt_mat).tocoo()
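`normalize_adj` computes the symmetric normalization D^{-1/2} A D^{-1/2}, using separate row- and column-degree scalings when the matrix is rectangular. A dense NumPy sketch of the same rule (hypothetical helper `sym_normalize`; the real code works on scipy sparse matrices):

```python
import numpy as np

def sym_normalize(adj):
    """D_r^{-1/2} @ A @ D_c^{-1/2}; for a symmetric square A the two
    degree vectors coincide, matching normalize_adj."""
    adj = np.asarray(adj, dtype=float)
    d_row = adj.sum(axis=1)
    d_col = adj.sum(axis=0)
    with np.errstate(divide='ignore'):
        # zero-degree nodes get a zero scaling instead of inf
        inv_r = np.where(d_row > 0, d_row ** -0.5, 0.0)
        inv_c = np.where(d_col > 0, d_col ** -0.5, 0.0)
    return inv_r[:, None] * adj * inv_c[None, :]
```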
    
    def unique_numpy(self, row, col):
        hash_vals = row * args.node_num + col
        hash_vals = np.unique(hash_vals).astype(np.int64)
        col = hash_vals % args.node_num
        row = (hash_vals - col).astype(np.int64) // args.node_num
        return row, col
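`unique_numpy` deduplicates edge pairs by hashing each `(row, col)` into the single integer `row * node_num + col`, uniquifying, and inverting the hash. A standalone sketch (hypothetical helper `unique_edges`; int64 casts guard against overflow for large graphs):

```python
import numpy as np

def unique_edges(row, col, num_nodes):
    """Deduplicate (row, col) pairs via the scalar hash
    row * num_nodes + col, mirroring unique_numpy."""
    row = np.asarray(row, dtype=np.int64)
    col = np.asarray(col, dtype=np.int64)
    hashes = np.unique(row * num_nodes + col)
    return hashes // num_nodes, hashes % num_nodes
```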

    def make_torch_adj(self, mat, unidirectional_for_asym=False):
        if mat.shape[0] == mat.shape[1]:
            _row = mat.row
            _col = mat.col
            row = np.concatenate([_row, _col]).astype(np.int64)
            col = np.concatenate([_col, _row]).astype(np.int64)
            row, col = self.unique_numpy(row, col)
            data = np.ones_like(row)
            mat = coo_matrix((data, (row, col)), mat.shape)
            if args.selfloop == 1:
                mat = (mat + sp.eye(mat.shape[0])) * 1.0
            normed_asym_mat = self.normalize_adj(mat)
            row = t.from_numpy(normed_asym_mat.row).long()
            col = t.from_numpy(normed_asym_mat.col).long()
            idxs = t.stack([row, col], dim=0)
            vals = t.from_numpy(normed_asym_mat.data).float()
            shape = t.Size(normed_asym_mat.shape)
            asym_adj = t.sparse.FloatTensor(idxs, vals, shape)
            return asym_adj
        elif unidirectional_for_asym:
            mat = (mat != 0) * 1.0
            mat = self.normalize_adj(mat, log=True)
            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
            vals = t.from_numpy(mat.data.astype(np.float32))
            shape = t.Size(mat.shape)
            return t.sparse.FloatTensor(idxs, vals, shape)
        else:
            # make ui adj
            a = sp.csr_matrix((args.user_num, args.user_num))
            b = sp.csr_matrix((args.item_num, args.item_num))
            mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
            mat = (mat != 0) * 1.0
            if args.selfloop == 1:
                mat = (mat + sp.eye(mat.shape[0])) * 1.0
            mat = self.normalize_adj(mat)

            # make cuda tensor
            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
            vals = t.from_numpy(mat.data.astype(np.float32))
            shape = t.Size(mat.shape)
            return t.sparse.FloatTensor(idxs, vals, shape)

    def load_data(self):
        tst_mat = self.load_one_file(self.tstfile)
        trn_mat = self.load_one_file(self.trnfile)
        # fewshot_mat = self.load_one_file(self.fewshotfile)
        if self.feat_file is not None:
            self.feats = t.from_numpy(self.load_feats(self.feat_file)).float()
            args.featdim = self.feats.shape[1]
        else:
            self.feats = None
            args.featdim = args.latdim

        if trn_mat.shape[0] != trn_mat.shape[1]:
            args.user_num, args.item_num = trn_mat.shape
            args.node_num = args.user_num + args.item_num
            print('Dataset: {data_name}, User num: {user_num}, Item num: {item_num}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, user_num=args.user_num, item_num=args.item_num, node_num=args.node_num, edge_num=trn_mat.nnz))
        else:
            args.node_num = trn_mat.shape[0]
            print('Dataset: {data_name}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, node_num=args.node_num, edge_num=trn_mat.nnz+tst_mat.nnz))
        if args.tst_mode == 'tst':
            tst_data = NodeTstData(tst_mat)
            self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)
            # tst_loss_data = TrnData(tst_mat)
            # self.tst_loss_loader = data.DataLoader(tst_loss_data, batch_size=args.batch, shuffle=False, num_workers=0)
            self.tst_input_adj = self.make_torch_adj(trn_mat)
        elif args.tst_mode == 'val':
            raise Exception('Validation not made for node classification')
        else:
            raise Exception('Specify proper test mode')


        self.trn_mat = trn_mat
        trn_data = TrnData(self.trn_mat)
        self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0)
        if args.tst_mode == 'tst':
            self.trn_input_adj = self.tst_input_adj
        else:
            self.trn_input_adj = self.make_torch_adj(trn_mat)

        if self.trn_mat.shape[0] == self.trn_mat.shape[1]:
            self.asym_adj = self.trn_input_adj
        else:
            self.asym_adj = self.make_torch_adj(self.trn_mat, unidirectional_for_asym=True)
        self.make_projectors()
        self.reproj_steps = max(len(self.trn_loader.dataset) // (10 * args.batch), args.proj_trn_steps)
        self.ratio_500_all = 250 / len(self.trn_loader)  # keep-probability that caps this dataset at roughly 250 training steps per epoch
    
    def make_projectors(self):
        with t.no_grad():
            projectors = []
            if args.proj_method == 'adj_svd' or args.proj_method == 'both':
                tem = self.asym_adj.to(args.devices[0])
                projectors = [Adj_Projector(tem)]
            if self.feats is not None and args.proj_method != 'adj_svd':
                tem = self.feats.to(args.devices[0])
                projectors.append(Feat_Projector(tem))
            assert (args.tst_mode == 'tst' and args.trn_mode == 'train-all') or (args.tst_mode == 'val' and args.trn_mode == 'fewshot')
            feats = projectors[0]()
            if len(projectors) == 2:
                feats2 = projectors[1]()
                feats = feats + feats2

            try:
                self.projectors = self.topo_encoder(self.trn_input_adj.to(args.devices[0]), feats.to(args.devices[0])).detach().cpu()
            except Exception:
                print(f'{self.data_name} memory overflow')
                mean, std = feats.mean(dim=-1, keepdim=True), feats.std(dim=-1, keepdim=True)
                tem_adj = self.trn_input_adj.to(args.devices[0])
                mem_cache = 256
                projectors_list = []
                # ceil-divide so the trailing chunk of columns is not dropped
                for i in range((feats.shape[1] + mem_cache - 1) // mem_cache):
                    st, ed = i * mem_cache, min((i + 1) * mem_cache, feats.shape[1])
                    tem_feats = (feats[:, st:ed] - mean) / (std + 1e-8)
                    tem_feats = self.topo_encoder(tem_adj, tem_feats.to(args.devices[0]), normed=True).detach().cpu()
                    projectors_list.append(tem_feats)
                self.projectors = t.concat(projectors_list, dim=-1)
            t.cuda.empty_cache()

class TrnData(data.Dataset):
    def __init__(self, coomat):
        self.ancs, self.poss = coomat.row, coomat.col
        # self.dokmat = set(list(map(lambda idx: (self.ancs[idx], self.poss[idx]), range(len(self.ancs)))))
        # self.dokmat = coomat.todok()
        self.negs = np.zeros(len(self.ancs)).astype(np.int32)
        self.cand_num = coomat.shape[1]
        self.neg_shift = 0 if coomat.shape[0] == coomat.shape[1] else coomat.shape[0]
        self.poss = coomat.col + self.neg_shift
        self.neg_sampling()
    
    def neg_sampling(self):
        # sample negatives from the candidate (column) side only, then shift into the unified id space
        self.negs = np.random.randint(self.cand_num, size=self.poss.shape[0]) + self.neg_shift
        # self.negs = np.zeros_like(self.ancs)
        # for i in range(len(self.ancs)):
        #     u = self.ancs[i]
        #     while True:
        #         i_neg = np.random.randint(self.cand_num)
        #         if (u, i_neg) not in self.dokmat:
        #             break
        #     self.negs[i] = i_neg
        # self.negs += self.neg_shift

    def __len__(self):
        return len(self.ancs)
    
    def __getitem__(self, idx):
        return self.ancs[idx], self.poss[idx], self.negs[idx]

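The index-shift trick used by `TrnData` above can be sketched in isolation: for a bipartite matrix, column (item) ids are offset by the row count so that users and items share one id space. A minimal sketch with a toy matrix (illustrative only, not data from the repo):

```python
import numpy as np
import scipy.sparse as sp

# Toy 2-user x 3-item interaction matrix (bipartite) in COO form.
mat = sp.coo_matrix(np.array([[1, 0, 1],
                              [0, 1, 0]]))

# Rows (users) and columns (items) live in different id spaces, so the
# column ids are shifted by the row count to share a single id space,
# mirroring TrnData's neg_shift logic. Square (homogeneous) graphs get
# no shift because both axes already index the same nodes.
neg_shift = 0 if mat.shape[0] == mat.shape[1] else mat.shape[0]
ancs = mat.row                 # anchor (user) ids: 0..1
poss = mat.col + neg_shift     # positive (item) ids: 2..4
```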
class NodeTstData(data.Dataset):
    def __init__(self, tst_mat):
        self.class_num = tst_mat.shape[1]
        self.nodes, self.labels = tst_mat.row, tst_mat.col

    def __len__(self):
        return len(self.nodes)
    
    def __getitem__(self, idx):
        return self.nodes[idx], self.labels[idx]

class JointTrnData(data.Dataset):
    def __init__(self, dataset_list):
        self.batch_dataset_ids = []
        self.batch_st_ed_list = []
        self.dataset_list = dataset_list
        for dataset_id, dataset in enumerate(dataset_list):
            samp_num = len(dataset) // args.batch + (1 if len(dataset) % args.batch != 0 else 0)
            for j in range(samp_num):
                self.batch_dataset_ids.append(dataset_id)
                st = j * args.batch
                ed = min((j + 1) * args.batch, len(dataset))
                self.batch_st_ed_list.append((st, ed))
    
    def neg_sampling(self):
        for dataset in self.dataset_list:
            dataset.neg_sampling()

    def __len__(self):
        return len(self.batch_dataset_ids)
    
    def __getitem__(self, idx):
        st, ed = self.batch_st_ed_list[idx]
        dataset_id = self.batch_dataset_ids[idx]
        return *self.dataset_list[dataset_id][st: ed], dataset_id
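`JointTrnData` returns a whole pre-sliced batch per index, so its `DataLoader` is meant to run with `batch_size=1` and the trainer unwraps each field with `[0]`. A pure-Python sketch of this batch-of-batches pattern (toy datasets and a hypothetical `get_joint_item` helper):

```python
batch = 4
datasets = {0: list(range(10)), 1: list(range(7))}  # two toy datasets

# (dataset_id, start, end) per joint-batch, exactly one entry per
# fixed-size slice of each source dataset (last slice may be short).
batch_index = []
for ds_id, ds in datasets.items():
    n_batches = len(ds) // batch + (1 if len(ds) % batch != 0 else 0)
    for j in range(n_batches):
        st, ed = j * batch, min((j + 1) * batch, len(ds))
        batch_index.append((ds_id, st, ed))

def get_joint_item(idx):
    # Mirrors JointTrnData.__getitem__: one item == one full batch + its id.
    ds_id, st, ed = batch_index[idx]
    return datasets[ds_id][st:ed], ds_id
```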


================================================
FILE: node_classification/main.py
================================================
import torch as t
from torch import nn
import Utils.TimeLogger as logger
from Utils.TimeLogger import log
from params import args
from model import Expert, Feat_Projector, Adj_Projector, AnyGraph
from data_handler import MultiDataHandler, DataHandler
import numpy as np
import pickle
import os
import setproctitle
import time
from sklearn.metrics import f1_score

class Exp:
    def __init__(self, multi_handler):
        self.multi_handler = multi_handler
        print(list(map(lambda x: x.data_name, multi_handler.trn_handlers)))
        for group_id, tst_handlers in enumerate(multi_handler.tst_handlers_group):
            print(f'Test group {group_id}', list(map(lambda x: x.data_name, tst_handlers)))
        self.metrics = dict()
        trn_mets = ['Loss', 'preLoss']
        tst_mets = ['Acc', 'F1', 'Loss', 'preLoss']
        mets = trn_mets + tst_mets
        for met in mets:
            if met in trn_mets:
                self.metrics['Train' + met] = list()
            if met in tst_mets:
                for i in range(len(self.multi_handler.tst_handlers_group)):
                    self.metrics['Test' + str(i) + met] = list()
        
    def make_print(self, name, ep, reses, save, data_name=None):
        if data_name is None:
            ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)
        else:
            ret = 'Epoch %d/%d, %s %s: ' % (ep, args.epoch, data_name, name)
        for metric in reses:
            val = reses[metric]
            ret += '%s = %.4f, ' % (metric, val)
            tem = name + metric if data_name is None else name + data_name + metric
            if save and tem in self.metrics:
                self.metrics[tem].append(val)
        ret = ret[:-2] + '      '
        return ret
    
    def run(self):
        self.prepare_model()
        log('Model Prepared')
        stloc = 0
        if args.load_model != None:
            self.load_model()
            stloc = len(self.metrics['TrainLoss']) * args.tst_epoch - (args.tst_epoch - 1)
        best_f1, best_ep = 0, -1
        early_stop = False
        for ep in range(stloc, args.epoch):
            tst_flag = (ep % args.tst_epoch == 0)
            start_time = time.time()
            self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=True)
            reses = self.train_epoch()
            log(self.make_print('Train', ep, reses, tst_flag))
            self.multi_handler.remake_initial_projections()
            end_time = time.time()
            print(f'NOTICE: {end_time-start_time}')
            if tst_flag:
                for handler_group_id in range(len(self.multi_handler.tst_handlers_group)):
                    tst_handlers = self.multi_handler.tst_handlers_group[handler_group_id]
                    self.model.assign_experts(tst_handlers, reca=False, log_assignment=True)
                    acc, f1, tstnum = 0, 0, 0
                    for i, handler in enumerate(tst_handlers):
                        reses = self.test_epoch(handler, i)
                        log(self.make_print(handler.data_name, ep, reses, False))
                        acc += reses['Acc'] * reses['tstNum']
                        f1 += reses['F1'] * reses['tstNum']
                        tstnum += reses['tstNum']
                    reses = {'Acc': acc / tstnum, 'F1': f1 / tstnum}
                    log(self.make_print('Test'+str(handler_group_id), ep, reses, tst_flag))

                    if reses['F1'] > best_f1:
                        best_f1 = reses['F1']
                        best_ep = ep
                self.save_history()
            print()

        for test_group_id in range(len(self.multi_handler.tst_handlers_group)):
            repeat_times = 10
            overall_acc, overall_f1 = np.zeros(repeat_times), np.zeros(repeat_times)
            overall_tstnum = 0
            tst_handlers = self.multi_handler.tst_handlers_group[test_group_id]
            if args.assignment == 'one-graph-one-expert':
                self.model.assign_experts(tst_handlers, reca=False, log_assignment=True)
            for i, handler in enumerate(tst_handlers):
                mets = dict()
                for _ in range(repeat_times):
                    handler.make_projectors()
                    if not args.assignment == 'one-graph-one-expert':
                        self.model.assign_experts([handler], reca=False, log_assignment=False)
                    reses = self.test_epoch(handler, i if args.assignment == 'one-graph-one-expert' else 0)
                    log(self.make_print('Test', args.epoch, reses, False))
                    for met in reses:
                        if met not in mets:
                            mets[met] = []
                        mets[met].append(reses[met])
                tstnum = reses['tstNum']
                tot_reses = dict()
                for met in reses:
                    tem_arr = np.array(mets[met])
                    tot_reses[met + '_std'] = tem_arr.std()
                    tot_reses[met + '_mean'] = tem_arr.mean()

                overall_acc += np.array(mets['Acc']) * tstnum
                overall_f1 += np.array(mets['F1']) * tstnum
                overall_tstnum += tstnum
                log(self.make_print('Test', args.epoch, tot_reses, False, handler.data_name))
            overall_acc /= overall_tstnum
            overall_f1 /= overall_tstnum
            overall_res = dict()
            overall_res['Acc_mean'] = overall_acc.mean()
            overall_res['Acc_std'] = overall_acc.std()
            overall_res['F1_mean'] = overall_f1.mean()
            overall_res['F1_std'] = overall_f1.std()
            log(self.make_print('Overall Test', args.epoch, overall_res, False))
        self.save_history()

    def print_model_size(self):
        total_params = 0
        trainable_params = 0
        non_trainable_params = 0
        for param in self.model.parameters():
            tem = np.prod(param.size())
            total_params += tem
            if param.requires_grad:
                trainable_params += tem
            else:
                non_trainable_params += tem
        print(f'Total params: {total_params/1e6}')
        print(f'Trainable params: {trainable_params/1e6}')
        print(f'Non-trainable params: {non_trainable_params/1e6}')

    def prepare_model(self):
        self.model = AnyGraph()
        t.cuda.empty_cache()
        self.print_model_size()

    def train_epoch(self):
        self.model.train()
        trn_loader = self.multi_handler.joint_trn_loader
        trn_loader.dataset.neg_sampling()
        ep_loss, ep_preloss, ep_regloss = 0, 0, 0
        steps = len(trn_loader)
        tot_samp_num = 0
        counter = [0] * len(self.multi_handler.trn_handlers)
        reassign_steps = sum(list(map(lambda x: x.reproj_steps, self.multi_handler.trn_handlers)))
        for i, batch_data in enumerate(trn_loader):
            if args.epoch_max_step > 0 and i >= args.epoch_max_step:
                break
            ancs, poss, negs, dataset_id = batch_data
            ancs = ancs[0].long()
            poss = poss[0].long()
            negs = negs[0].long()
            dataset_id = dataset_id[0].long()
            tem_bar = self.multi_handler.trn_handlers[dataset_id].ratio_500_all
            if tem_bar < 1.0 and np.random.uniform() > tem_bar:
                steps -= 1
                continue

            expert = self.model.summon(dataset_id)#.cuda()
            opt = self.model.summon_opt(dataset_id)
            # adj = self.multi_handler.trn_handlers[dataset_id].trn_input_adj
            feats = self.multi_handler.trn_handlers[dataset_id].projectors
            loss, loss_dict = expert.cal_loss((ancs, poss, negs), feats)
            opt.zero_grad()
            loss.backward()
            # nn.utils.clip_grad_norm_(expert.parameters(), max_norm=20, norm_type=2)
            opt.step()

            sample_num = ancs.shape[0]
            tot_samp_num += sample_num
            ep_loss += loss.item() * sample_num
            ep_preloss += loss_dict['preloss'].item() * sample_num
            ep_regloss += loss_dict['regloss'].item()
            log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f        ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True)

            counter[dataset_id] += 1
            if (counter[dataset_id] + 1) % self.multi_handler.trn_handlers[dataset_id].reproj_steps == 0:
            # if args.proj_trn_steps > 0 and counter[dataset_id] >= args.proj_trn_steps:
                self.multi_handler.trn_handlers[dataset_id].make_projectors()
            if (i + 1) % reassign_steps == 0:
                self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=False)
        ret = dict()
        ret['Loss'] = ep_loss / tot_samp_num
        ret['preLoss'] = ep_preloss / tot_samp_num
        ret['regLoss'] = ep_regloss / steps
        t.cuda.empty_cache()
        return ret
    
    def make_trn_masks(self, numpy_usrs, csr_mat):
        trn_masks = csr_mat[numpy_usrs].tocoo()
        cand_size = trn_masks.shape[1]
        trn_masks = t.from_numpy(np.stack([trn_masks.row, trn_masks.col], axis=0)).long()
        return trn_masks, cand_size

    def test_loss_epoch(self, handler, dataset_id):
        with t.no_grad():
            tst_loader = handler.tst_loss_loader
            self.model.eval()
            expert = self.model.summon(dataset_id)#.cuda()
            ep_loss, ep_preloss, ep_regloss = 0, 0, 0
            steps = len(tst_loader)
            tot_samp_num = 0
            for i, batch_data in enumerate(tst_loader):
                ancs, poss, negs = batch_data
                ancs = ancs.long()
                poss = poss.long()
                negs = negs.long()
                # adj = handler.tst_input_adj
                feats = handler.projectors
                loss, loss_dict = expert.cal_loss((ancs, poss, negs), feats)
                
                sample_num = ancs.shape[0]
                tot_samp_num += sample_num
                ep_loss += loss.item() * sample_num
                ep_preloss += loss_dict['preloss'].item() * sample_num
                ep_regloss += loss_dict['regloss'].item()
                log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f        ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True)

        ret = dict()
        ret['Loss'] = ep_loss / tot_samp_num
        ret['preLoss'] = ep_preloss / tot_samp_num
        ret['regLoss'] = ep_regloss / steps
        ret['tot_samp_num'] = tot_samp_num
        t.cuda.empty_cache()
        return ret
    
    def test_epoch(self, handler, dataset_id):
        with t.no_grad():
            tst_loader = handler.tst_loader
            class_num = tst_loader.dataset.class_num
            self.model.eval()
            expert = self.model.summon(dataset_id)
            ep_acc, ep_tot = 0, 0
            steps = len(tst_loader)
            for i, batch_data in enumerate(tst_loader):
                nodes, labels = list(map(lambda x: x.long().cuda(), batch_data))
                feats = handler.projectors
                preds = expert.pred_for_node_test(nodes, class_num, feats, rerun_embed=False if i!=0 else True)
                if i == 0:
                    all_preds, all_labels = preds, labels
                else:
                    all_preds = t.concatenate([all_preds, preds])
                    all_labels = t.concatenate([all_labels, labels])
                hit = (labels == preds).float().sum().item()
                ep_acc += hit
                ep_tot += labels.shape[0]
                log('Steps %d/%d: hit = %d, tot = %d          ' % (i, steps, ep_acc, ep_tot), save=False, oneline=True)
        ret = dict()
        ret['Acc'] = ep_acc / ep_tot
        ret['F1'] = f1_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='macro')
        ret['tstNum'] = ep_tot
        t.cuda.empty_cache()
        return ret

    
    def calc_recall_ndcg(self, topLocs, tstLocs, batIds):
        assert topLocs.shape[0] == len(batIds)
        allRecall = allNdcg = 0
        for i in range(len(batIds)):
            temTopLocs = list(topLocs[i])
            temTstLocs = tstLocs[batIds[i]]
            tstNum = len(temTstLocs)
            maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))])
            recall = dcg = 0
            for val in temTstLocs:
                if val in temTopLocs:
                    recall += 1
                    dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2))
            recall = recall / tstNum
            ndcg = dcg / maxDcg
            allRecall += recall
            allNdcg += ndcg
        return allRecall, allNdcg
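The metric logic above can be checked by hand: DCG weights each hit by `1/log2(rank + 2)` and normalizes by the ideal DCG with all test items at the top ranks. A small worked example in numpy (toy ranking, values chosen for illustration):

```python
import numpy as np

top_locs = [5, 2, 9]   # top-3 ranked candidate ids for one user
tst_locs = [2, 7]      # held-out ground-truth ids
topk = 3

# Ideal DCG: ground-truth items occupying the best ranks.
max_dcg = sum(1.0 / np.log2(r + 2) for r in range(min(len(tst_locs), topk)))

hits = dcg = 0.0
for item in tst_locs:
    if item in top_locs:
        hits += 1
        dcg += 1.0 / np.log2(top_locs.index(item) + 2)  # rank 1 -> 1/log2(3)

recall = hits / len(tst_locs)  # 1 of 2 test items retrieved -> 0.5
ndcg = dcg / max_dcg
```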
    
    def save_history(self):
        if args.epoch == 0:
            return
        with open('../History/' + args.save_path + '.his', 'wb') as fs:
            pickle.dump(self.metrics, fs)

        content = {
            'model': self.model,
        }
        t.save(content, '../Models/' + args.save_path + '.mod')
        log('Model Saved: %s' % args.save_path)

    def load_model(self):
        ckp = t.load('../Models/' + args.load_model + '.mod')
        self.model = ckp['model']
        # self.model.set_initial_projection(self.handler.torch_adj)
        self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)

        with open('../History/' + args.load_model + '.his', 'rb') as fs:
            self.metrics = pickle.load(fs)
        log('Model Loaded')

if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if len(args.gpu.split(',')) == 2:
        args.devices = ['cuda:0', 'cuda:1']
    elif len(args.gpu.split(',')) > 2:
        raise Exception('At most 2 GPU devices are supported')
    else:
        args.devices = ['cuda:0', 'cuda:0']
    logger.saveDefault = True
    setproctitle.setproctitle('akaxia_AutoGraph')

    log('Start')

    
    datasets = dict()
    datasets['all'] = [
        'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech', 'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab', 'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1',
    ]
    datasets['ecommerce'] = [
        'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech'
    ]
    datasets['academic'] = [
        'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab'
    ]
    datasets['others'] = [
        'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1'
    ]
    datasets['div1'] = [
        'products_tech', 'yelp2018', 'yelp_textfeat', 'products_home', 'steam_textfeat', 'amazon_textfeat', 'amazon-book', 'citation-2019', 'citation-classic', 'pubmed', 'citeseer', 'ppa', 'p2p-Gnutella06', 'soc-Epinions1', 'email-Enron',
    ]
    datasets['div2'] = [
        'Photo', 'Goodreads', 'Fitness', 'ml1m', 'ml10m', 'gowalla', 'arxiv', 'arxiv-ta', 'cora', 'CS', 'collab', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'ddi', 'web-Stanford', 'roadNet-PA',
    ]
    datasets['node'] = [
        'cora', 'arxiv', 'pubmed', 'home', 'tech'
    ]

    if args.dataset_setting in datasets.keys():
        trn_datasets = tst_datasets = datasets[args.dataset_setting]
    elif args.dataset_setting in datasets['all']:
        trn_datasets = tst_datasets = [args.dataset_setting]
    elif '+' in args.dataset_setting:
        idx = args.dataset_setting.index('+')
        trn_datasets = datasets[args.dataset_setting[:idx]]
        tst_datasets = datasets[args.dataset_setting[idx+1:]]
    elif '_in_' in args.dataset_setting:
        idx = args.dataset_setting.index('_in_')
        tst_datasets_1 = datasets[args.dataset_setting[:idx]]
        tst_datasets_2 = datasets[args.dataset_setting[idx+len('_in_'):]]
        tst_datasets = []
        for data in tst_datasets_1:
            if data in tst_datasets_2:
                tst_datasets.append(data)
        trn_datasets = tst_datasets
    else:
        raise Exception('Unrecognized dataset setting: ' + args.dataset_setting)

    # trn_datasets = tst_datasets = ['products_home']
    if '+' not in args.dataset_setting:
        handler = MultiDataHandler(trn_datasets, [tst_datasets])
    else:
        handler = MultiDataHandler(trn_datasets, [trn_datasets, tst_datasets])
    log('Load Data')

    exp = Exp(handler)
    exp.run()
    print(args.load_model, args.dataset_setting)


================================================
FILE: node_classification/model.py
================================================
import torch as t
from torch import nn
import torch.nn.functional as F
from params import args
import numpy as np
from Utils.TimeLogger import log
from torch.nn import MultiheadAttention
from time import time

init = nn.init.xavier_uniform_
# init = nn.init.normal_
uniformInit = nn.init.uniform_

class FeedForwardLayer(nn.Module):
    def __init__(self, in_feat, out_feat, bias=True, act=None):
        super(FeedForwardLayer, self).__init__()
        self.linear = nn.Linear(in_feat, out_feat, bias=bias)#, dtype=t.bfloat16)
        if act == 'identity' or act is None:
            self.act = None
        elif act == 'leaky':
            self.act = nn.LeakyReLU(negative_slope=args.leaky)
        elif act == 'relu':
            self.act = nn.ReLU()
        elif act == 'relu6':
            self.act = nn.ReLU6()
        else:
            raise Exception('Unsupported activation: ' + str(act))
    
    def forward(self, embeds):
        if self.act is None:
            return self.linear(embeds)
        return self.act(self.linear(embeds))

class TopoEncoder(nn.Module):
    def __init__(self):
        super(TopoEncoder, self).__init__()

        self.layer_norm = nn.LayerNorm(args.latdim, elementwise_affine=False)

    def forward(self, adj, embeds, normed=False):
        with t.no_grad():
            if not normed:
                embeds = self.layer_norm(embeds)
            # embeds_list = []
            final_embeds = 0
            if args.gnn_layer == 0:
                final_embeds = embeds
                # embeds_list.append(embeds)
            for _ in range(args.gnn_layer):
                embeds = t.spmm(adj, embeds)
                final_embeds += embeds
                # embeds_list.append(embeds)
            embeds = final_embeds#sum(embeds_list)
        return embeds
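`TopoEncoder` is parameter-free: after LayerNorm it returns the sum of the first `gnn_layer` propagation steps, i.e. `(A + A^2 + ... + A^L) X` for a normalized adjacency `A`. A dense numpy sketch of the same computation (toy graph; `topo_encode` is a hypothetical stand-in, not the repo's API):

```python
import numpy as np

def topo_encode(adj, feats, n_layers=2):
    # Parameter-free smoothing: sum_{k=1..L} A^k X, with A assumed
    # normalized; n_layers == 0 passes features through untouched,
    # matching the gnn_layer == 0 branch above.
    out = np.zeros_like(feats) if n_layers > 0 else feats.copy()
    h = feats
    for _ in range(n_layers):
        h = adj @ h        # one sparse-matmul step in the real model
        out = out + h
    return out

# Symmetric 2-node toy graph, row-normalized.
adj = np.array([[0.5, 0.5],
                [0.5, 0.5]])
feats = np.eye(2)
```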

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(args.fc_layer)])
        self.layer_norms = nn.Sequential(*[nn.LayerNorm(args.latdim, elementwise_affine=True) for _ in range(args.fc_layer)])
        self.dropout = nn.Dropout(p=args.drop_rate)
    
    def forward(self, embeds):
        for i in range(args.fc_layer):
            embeds = self.layer_norms[i](self.dropout(self.dense_layers[i](embeds)) + embeds)
        return embeds

class GTLayer(nn.Module):
    def __init__(self):
        super(GTLayer, self).__init__()
        self.multi_head_attention = MultiheadAttention(args.latdim, args.head, dropout=0.1, bias=False)#, dtype=t.bfloat16)
        self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(2)])# bias=False
        self.layer_norm1 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
        self.layer_norm2 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)
        self.fc_dropout = nn.Dropout(p=args.drop_rate)
    
    def _pick_anchors(self, embeds):
        perm = t.randperm(embeds.shape[0])
        anchors = perm[:args.anchor]
        return embeds[anchors]
    
    def forward(self, embeds):
        anchor_embeds = self._pick_anchors(embeds)
        _anchor_embeds, _ = self.multi_head_attention(anchor_embeds, embeds, embeds)
        anchor_embeds = _anchor_embeds + anchor_embeds
        _embeds, _ = self.multi_head_attention(embeds, anchor_embeds, anchor_embeds, need_weights=False)
        embeds = self.layer_norm1(_embeds + embeds)
        _embeds = self.fc_dropout(self.dense_layers(embeds))
        embeds = (self.layer_norm2(_embeds + embeds))
        return embeds
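`GTLayer` sidesteps full O(N^2) self-attention by routing through `args.anchor` randomly sampled anchor nodes: anchors first attend over all nodes, then every node attends over only the k anchors, for O(N*k) cost. A single-head numpy sketch of that routing (simplified: no learned projections, dropout, or layer norm; `anchor_attention` is a hypothetical helper):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def anchor_attention(embeds, n_anchor, rng):
    d = embeds.shape[1]
    # Step 1: k random anchors gather information from all N nodes.
    anchors = embeds[rng.permutation(len(embeds))[:n_anchor]]
    anchors = softmax(anchors @ embeds.T / np.sqrt(d)) @ embeds + anchors
    # Step 2: every node attends only to the k anchors (O(N*k), not O(N^2)).
    out = softmax(embeds @ anchors.T / np.sqrt(d)) @ anchors
    return out + embeds  # residual connection, as in the layer

rng = np.random.default_rng(0)
x = rng.standard_normal((100, 8))
y = anchor_attention(x, n_anchor=4, rng=rng)
```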

class GraphTransformer(nn.Module):
    def __init__(self):
        super(GraphTransformer, self).__init__()
        self.gt_layers = nn.Sequential(*[GTLayer() for i in range(args.gt_layer)])

    def forward(self, embeds):
        for i, layer in enumerate(self.gt_layers):
            embeds = layer(embeds) / args.scale_layer
        return embeds

class Feat_Projector(nn.Module):
    def __init__(self, feats):
        super(Feat_Projector, self).__init__()

        if args.proj_method == 'uniform':
            self.proj_embeds = self.uniform_proj(feats)
        elif args.proj_method == 'svd' or args.proj_method == 'both':
            self.proj_embeds = self.svd_proj(feats)
        elif args.proj_method == 'random':
            self.proj_embeds = self.random_proj(feats)
        elif args.proj_method == 'original':
            self.proj_embeds = feats
        self.proj_embeds = t.flip(self.proj_embeds, dims=[-1])
        self.proj_embeds = self.proj_embeds.detach()
    
    def svd_proj(self, feats):
        if args.latdim > feats.shape[0] or args.latdim > feats.shape[1]:
            dim = min(feats.shape[0], feats.shape[1])
            decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=dim, niter=args.niter)
            decom_feats = t.concat([decom_feats, t.zeros([decom_feats.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
            s = t.concat([s, t.zeros(args.latdim - dim).to(args.devices[0])])
        else:
            decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=args.latdim, niter=args.niter)
        decom_feats = decom_feats @ t.diag(t.sqrt(s))
        return decom_feats.cpu()
    
    def uniform_proj(self, feats):
        projection = init(t.empty(args.featdim, args.latdim))
        return feats @ projection
    
    def random_proj(self, feats):
        projection = init(t.empty(feats.shape[0], args.latdim))
        return projection
    
    def forward(self):
        return self.proj_embeds

class Adj_Projector(nn.Module):
    def __init__(self, adj):
        super(Adj_Projector, self).__init__()

        if args.proj_method == 'adj_svd' or args.proj_method == 'both':
            self.proj_embeds = self.svd_proj(adj)
        self.proj_embeds = self.proj_embeds.detach()
    
    def svd_proj(self, adj):
        q = args.latdim
        if args.latdim > adj.shape[0] or args.latdim > adj.shape[1]:
            dim = min(adj.shape[0], adj.shape[1])
            svd_u, s, svd_v = t.svd_lowrank(adj, q=dim, niter=args.niter)
            svd_u = t.concat([svd_u, t.zeros([svd_u.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
            svd_v = t.concat([svd_v, t.zeros([svd_v.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)
            s = t.concat([s, t.zeros(args.latdim-dim).to(args.devices[0])])
        else:
            svd_u, s, svd_v = t.svd_lowrank(adj, q=q, niter=args.niter)
        svd_u = svd_u @ t.diag(t.sqrt(s))
        svd_v = svd_v @ t.diag(t.sqrt(s))
        if adj.shape[0] != adj.shape[1]:
            projection = t.concat([svd_u, svd_v], dim=0)
        else:
            projection = svd_u + svd_v
        return projection.cpu()
    
    def forward(self):
        return self.proj_embeds
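`Adj_Projector.svd_proj` builds structural node embeddings as `U*sqrt(S)` and `V*sqrt(S)` from a low-rank SVD of the adjacency, stacking both factors for bipartite graphs and summing them for square ones. A numpy sketch of the same scheme (toy matrix; `adj_svd_embed` is a hypothetical helper):

```python
import numpy as np

def adj_svd_embed(adj, dim):
    # Scale singular vectors by sqrt of singular values so that
    # row_embed @ col_embed.T approximates the adjacency.
    u, s, vt = np.linalg.svd(adj, full_matrices=False)
    u, s, v = u[:, :dim], s[:dim], vt[:dim].T
    row_embed = u * np.sqrt(s)
    col_embed = v * np.sqrt(s)
    if adj.shape[0] != adj.shape[1]:   # bipartite: stack both sides
        return np.vstack([row_embed, col_embed])
    return row_embed + col_embed       # square: one embedding per node

adj = np.array([[1.0, 0.0, 1.0],
                [0.0, 1.0, 0.0]])      # rank-2 toy bipartite adjacency
embeds = adj_svd_embed(adj, dim=2)     # 2 row nodes + 3 column nodes
```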

class Expert(nn.Module):
    def __init__(self):
        super(Expert, self).__init__()
        
        self.topo_encoder = TopoEncoder().to(args.devices[0])
        if args.nn == 'mlp':
            self.trainable_nn = MLP().to(args.devices[1])
        else:
            self.trainable_nn = GraphTransformer().to(args.devices[1])
        self.trn_count = 1
    
    def forward(self, projectors, pck_nodes=None):
        embeds = projectors.to(args.devices[1])
        if pck_nodes is not None:
            embeds = embeds[pck_nodes]
        embeds = self.trainable_nn(embeds)
        return embeds

    def pred_norm(self, pos_preds, neg_preds):
        pos_preds_num = pos_preds.shape[0]
        neg_preds_shape = neg_preds.shape
        preds = t.concat([pos_preds, neg_preds.view(-1)])
        preds = preds - preds.max()
        pos_preds = preds[:pos_preds_num]
        neg_preds = preds[pos_preds_num:].view(neg_preds_shape)
        return pos_preds, neg_preds
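`pred_norm` subtracts the joint maximum from positive and negative scores before they reach `exp()` in the cross-entropy branch of `cal_loss`; softmax-style losses are invariant to this shift, while it keeps `exp()` from overflowing. A numpy check (hypothetical helper name, toy scores):

```python
import numpy as np

def stable_softmax_ce(pos, negs):
    # Shift every score by the joint max, as pred_norm does; exp() then
    # stays bounded in (0, 1] and the loss value is unchanged.
    m = max(pos, negs.max())
    pos, negs = pos - m, negs - m
    # -log softmax of the positive against {positive} + negatives
    return -(pos - np.log(np.exp(pos) + np.exp(negs).sum()))

pos, negs = 1000.0, np.array([999.0, 998.0])  # exp(1000) would overflow
loss = stable_softmax_ce(pos, negs)
```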
    
    def cal_loss(self, batch_data, projectors):
        ancs, poss, negs = list(map(lambda x: x.to(args.devices[1]), batch_data))
        self.trn_count += ancs.shape[0]
        pck_nodes = t.concat([ancs, poss, negs])
        final_embeds = self.forward(projectors, pck_nodes)
        # anc_embeds, pos_embeds, neg_embeds = final_embeds[ancs], final_embeds[poss], final_embeds[negs]
        anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds, [ancs.shape[0]] * 3)
        if final_embeds.isinf().any() or final_embeds.isnan().any():
            raise Exception('Final embedding fails')
        
        if args.loss == 'ce':
            pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T)
            if pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any():
                raise Exception('Preds fails')
            pos_loss = pos_preds
            neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log()
            pre_loss = -(pos_loss - neg_loss).mean()
        elif args.loss == 'bpr':
            pos_preds = (anc_embeds * pos_embeds).sum(-1)
            neg_preds = (anc_embeds * neg_embeds).sum(-1)
            pos_loss, neg_loss = pos_preds, neg_preds
            pre_loss = -((pos_preds - neg_preds).sigmoid() + 1e-10).log().mean() 

        if t.isinf(pre_loss).any() or t.isnan(pre_loss).any():
            raise Exception('NaN or Inf')

        reg_loss = sum(list(map(lambda W: W.norm(2).square() * args.reg, self.parameters())))
        loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean()}
        return pre_loss + reg_loss, loss_dict
    
    def pred_for_test(self, batch_data, cand_size, projectors, rerun_embed=True):
        ancs, trn_mask = list(map(lambda x: x.to(args.devices[1]), batch_data))
        if rerun_embed:
            try:
                final_embeds = self.forward(projectors)
            except Exception:
                final_embeds_list = []
                div = args.batch * 3
                temlen = projectors.shape[0] // div
                for i in range(temlen):
                    st, ed = div * i, div * (i + 1)
                    tem_projectors = projectors[st: ed, :]
                    final_embeds_list.append(self.forward(tem_projectors))
                if temlen * div < projectors.shape[0]:
                    tem_projectors = projectors[temlen*div:, :]
                    final_embeds_list.append(self.forward(tem_projectors))
                final_embeds = t.concat(final_embeds_list, dim=0)
            self.final_embeds = final_embeds
        final_embeds = self.final_embeds
        anc_embeds = final_embeds[ancs]
        cand_embeds = final_embeds[-cand_size:]

        mask_mat = t.sparse.FloatTensor(trn_mask, t.ones(trn_mask.shape[1]).to(args.devices[1]), t.Size([ancs.shape[0], cand_size]))
        dense_mat = mask_mat.to_dense()
        all_preds = anc_embeds @ cand_embeds.T * (1 - dense_mat) - dense_mat * 1e8
        return all_preds
    
    def pred_for_node_test(self, nodes, cand_size, feats, rerun_embed=True):
        if rerun_embed:
            final_embeds = self.forward(feats)
            self.final_embeds = final_embeds
        final_embeds = self.final_embeds
        anc_embeds = final_embeds[nodes]
        class_embeds = final_embeds[-cand_size:]
        preds = anc_embeds @ class_embeds.T
        return t.argmax(preds, dim=-1)

    def attempt(self, topo_embeds, dataset):
        final_embeds = self.trainable_nn(topo_embeds)
        rows, cols, negs = list(map(lambda x: t.from_numpy(x).long().to(args.devices[1]), [dataset.ancs, dataset.poss, dataset.negs]))
        if rows.shape[0] > args.attempt_cache:
            random_perm = t.randperm(rows.shape[0], device=args.devices[0])
            pck_perm = random_perm[:args.attempt_cache]
            rows = rows[pck_perm]
            cols = cols[pck_perm]
            negs = negs[pck_perm]
        while True:
            try:
                row_embeds = final_embeds[rows]
                col_embeds = final_embeds[cols]
                neg_embeds = final_embeds[negs]
                score = ((row_embeds * col_embeds).sum(-1) - (row_embeds * neg_embeds).sum(-1)).sigmoid().mean().item()
                break
            except Exception:
                args.attempt_cache = args.attempt_cache // 2
                random_perm = t.randperm(rows.shape[0], device=args.devices[0])
                pck_perm = random_perm[:args.attempt_cache]
                rows = rows[pck_perm]
                cols = cols[pck_perm]
                negs = negs[pck_perm]
        t.cuda.empty_cache()
        return score

class AnyGraph(nn.Module):
    def __init__(self):
        super(AnyGraph, self).__init__()
        self.experts = nn.ModuleList([Expert() for _ in range(args.expert_num)]).cuda()
        self.opts = list(map(lambda expert: t.optim.Adam(expert.parameters(), lr=args.lr, weight_decay=0), self.experts))

    def one_graph_one_expert(self, handlers, log_assignment=False):
        self.assignment = [i for i in range(len(handlers))]
        if log_assignment:
            print('\n----------\nAssignment')
            for dataset_id, handler in enumerate(handlers):
                print(handler.data_name, f'{self.assignment[dataset_id]}')
    
    def store_history_assignment(self, assignment):
        pass
        
    def assign_experts(self, handlers, reca=True, log_assignment=False):
        if args.expert_num == 1:
            self.assignment = [0] * len(handlers)
            return
        if args.assignment == 'one-graph-one-expert':
            self.one_graph_one_expert(handlers, log_assignment)
            return
        if args.assignment not in ['top1', 'top2']:
            raise Exception('Unrecognized assignment method')
        try:
            expert_scores = np.array(list(map(lambda expert: expert.trn_count, self.experts)))
            expert_scores = (1.0 - expert_scores / np.sum(expert_scores)) * args.reca_range + 1.0 - args.reca_range / 2
        except Exception:
            expert_scores = np.ones(len(self.experts))
        with t.no_grad():
            assignment = [list() for i in range(len(handlers))]
            for dataset_id, handler in enumerate(handlers):
                topo_embeds = handler.projectors.to(args.devices[1])
                for expert_id, expert in enumerate(self.experts):
                    expert = expert.to(args.devices[1])
                    score = expert.attempt(topo_embeds, handler.trn_loader.dataset)
                    if reca:
                        score *= expert_scores[expert_id]
                    assignment[dataset_id].append((expert_id, score))
                assignment[dataset_id].sort(key=lambda x: x[1], reverse=True)
                if args.assignment == 'top2':
                    if assignment[dataset_id][0][1] - assignment[dataset_id][1][1] < 0.05 and np.random.uniform() > 0.5:
                        tem = assignment[dataset_id][0]
                        assignment[dataset_id][0] = assignment[dataset_id][1]
                        assignment[dataset_id][1] = tem
            if log_assignment:
                print('\n----------\nAssignment')
                for dataset_id, handler in enumerate(handlers):
                    out = ''
                    for exp_idx in range(min(4, len(self.experts))):
                        out += f'({assignment[dataset_id][exp_idx][0]}, {assignment[dataset_id][exp_idx][1]}) '
                    print(handler.data_name, out)
                print('----------\n')

            self.assignment = list(map(lambda x: x[0][0], assignment))
    
    def summon(self, dataset_id):
        return self.experts[self.assignment[dataset_id]]
    
    def summon_opt(self, dataset_id):
        return self.opts[self.assignment[dataset_id]]


================================================
FILE: node_classification/params.py
================================================
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Model Parameters')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--batch', default=4096, type=int, help='training batch size')
    parser.add_argument('--tst_batch', default=256, type=int, help='testing batch size (number of users)')
    parser.add_argument('--epoch', default=100, type=int, help='number of epochs')
    parser.add_argument('--save_path', default='tem', help='file name to save model and training record')
    parser.add_argument('--load_model', default=None, help='model name to load')
    parser.add_argument('--data', default='ml1m', type=str, help='name of dataset')
    parser.add_argument('--tst_epoch', default=1, type=int, help='number of epoch to test while training')
    parser.add_argument('--gpu', default='0', type=str, help='indicates which gpu to use')
    parser.add_argument('--topk', default=20, type=int, help='topk in evaluation')
    parser.add_argument('--epoch_max_step', default=-1, type=int, help='indicates the maximum number of steps in one epoch, -1 denotes full steps')
    parser.add_argument('--trn_mode', default='train-all', type=str, help='[fewshot, train-all]')
    parser.add_argument('--tst_mode', default='tst', type=str, help='[tst, val]')
    parser.add_argument('--eval_loss', default=True, type=bool, help='whether to use CE loss to evaluate test performance')
    parser.add_argument('--ratio_fewshot_set', default=0.5, type=float, help='ratio of fewshot set')
    parser.add_argument('--shot', default=5, type=int, help='number of shots for each node')
    parser.add_argument('--tst_steps', default=-1, type=int, help='number of test steps, -1 indicates all')

    parser.add_argument('--reg', default=1e-7, type=float, help='weight decay regularizer')
    parser.add_argument('--latdim', default=512, type=int, help='latent dimensionality')
    parser.add_argument('--gnn_layer', default=3, type=int, help='number of gnn layers')
    parser.add_argument('--fc_layer', default=8, type=int, help='number of fully-connected layers')
    parser.add_argument('--gt_layer', default=2, type=int, help='number of graph transformer layers')
    parser.add_argument('--head', default=4, type=int, help='number of attention heads')
    parser.add_argument('--anchor', default=256, type=int, help='number of anchor nodes in the compressed graph transformer')
    parser.add_argument('--act', default='relu', type=str, help='activation function')
    parser.add_argument('--dataset_setting', default='training', type=str, help='which set of datasets to use')
    parser.add_argument('--assignment', default='top1', type=str, help='assigning method')
    parser.add_argument('--scale_layer', default=10, type=float, help='per-layer scale factor')
    parser.add_argument('--leaky', default=0.5, type=float, help='slope of leaky relu activation')
    parser.add_argument('--drop_rate', default=0.1, type=float, help='ratio of dropout')
    parser.add_argument('--reca_range', default=0.2, type=float, help='range of recalibration')
    parser.add_argument('--selfloop', default=0, type=int, help='indicating using self-loop or not')
    parser.add_argument('--niter', default=2, type=int, help='number of iterations in svd')
    parser.add_argument('--expert_num', default=8, type=int, help='number of experts')
    parser.add_argument('--loss', default='ce', type=str, help='loss function')
    parser.add_argument('--proj_method', default='both', type=str, help='feature projection method')
    parser.add_argument('--nn', default='mlp', type=str, help='what trainable network to use')
    parser.add_argument('--proj_trn_steps', default=100, type=int, help='number of training steps for one initial projection')
    parser.add_argument('--attempt_cache', default=10000000, type=int, help='maximum number of training edges sampled when an expert attempts a dataset during assignment')
    return parser.parse_args()
args = parse_args()
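
The `--reca_range` flag above feeds the load-balancing recalibration in `AnyGraph.assign_experts`, which computes `expert_scores = (1 - count/total) * reca_range + 1 - reca_range/2` so that heavily trained experts are slightly penalized when competing for new datasets. A self-contained numeric sketch (NumPy stands in for torch; the helper name `recalibrate` is mine):

```python
import numpy as np

def recalibrate(trn_counts, reca_range=0.2):
    """Sketch of the score recalibration in AnyGraph.assign_experts:
    experts with a larger training count get a smaller multiplier,
    under-used experts a larger one, within a band of width reca_range."""
    counts = np.asarray(trn_counts, dtype=float)
    return (1.0 - counts / counts.sum()) * reca_range + 1.0 - reca_range / 2

even = recalibrate([10, 10, 10, 10])    # equal load -> identical multipliers
skewed = recalibrate([30, 10, 10, 10])  # expert 0 trained 3x as often
```

With four equally loaded experts each multiplier is 1.05; in the skewed case expert 0 drops to 1.0 while the others rise, nudging new datasets toward the less-used experts.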

================================================
FILE: params.py
================================================
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Model Parameters')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--batch', default=4096, type=int, help='training batch size')
    parser.add_argument('--tst_batch', default=256, type=int, help='testing batch size (number of users)')
    parser.add_argument('--epoch', default=100, type=int, help='number of epochs')
    parser.add_argument('--save_path', default='tem', help='file name to save model and training record')
    parser.add_argument('--load_model', default=None, help='model name to load')
    parser.add_argument('--tst_epoch', default=1, type=int, help='number of epoch to test while training')
    parser.add_argument('--gpu', default='0', type=str, help='indicates which gpu to use')
    parser.add_argument('--topk', default=20, type=int, help='topk in evaluation')
    parser.add_argument('--epoch_max_step', default=-1, type=int, help='indicates the maximum number of steps in one epoch, -1 denotes full steps')
    parser.add_argument('--trn_mode', default='train-all', type=str, help='[fewshot, train-all]')
    parser.add_argument('--tst_mode', default='tst', type=str, help='[tst, val]')
    parser.add_argument('--eval_loss', default=True, type=bool, help='whether to use CE loss to evaluate test performance')
    parser.add_argument('--ratio_fewshot', default=0.1, type=float, help='ratio of fewshot set')
    parser.add_argument('--tst_steps', default=-1, type=int, help='number of test steps, -1 indicates all')

    parser.add_argument('--reg', default=1e-7, type=float, help='weight decay regularizer')
    parser.add_argument('--latdim', default=512, type=int, help='latent dimensionality')
    parser.add_argument('--gnn_layer', default=3, type=int, help='number of gnn layers')
    parser.add_argument('--fc_layer', default=8, type=int, help='number of fully-connected layers')
    parser.add_argument('--gt_layer', default=2, type=int, help='number of graph transformer layers')
    parser.add_argument('--head', default=4, type=int, help='number of attention heads')
    parser.add_argument('--anchor', default=256, type=int, help='number of anchor nodes in the compressed graph transformer')
    parser.add_argument('--act', default='relu', type=str, help='activation function')
    parser.add_argument('--dataset_setting', default='training', type=str, help='which set of datasets to use')
    parser.add_argument('--scale_layer', default=10, type=float, help='per-layer scale factor')
    parser.add_argument('--leaky', default=0.5, type=float, help='slope of leaky relu activation')
    parser.add_argument('--drop_rate', default=0.1, type=float, help='ratio of dropout')
    parser.add_argument('--reca_range', default=0.2, type=float, help='range of recalibration')
    parser.add_argument('--selfloop', default=0, type=int, help='indicating using self-loop or not')
    parser.add_argument('--niter', default=2, type=int, help='number of iterations in svd')
    parser.add_argument('--expert_num', default=8, type=int, help='number of experts')
    parser.add_argument('--loss', default='ce', type=str, help='loss function')
    parser.add_argument('--proj_method', default='both', type=str, help='feature projection method')
    parser.add_argument('--nn', default='mlp', type=str, help='what trainable network to use')
    parser.add_argument('--proj_trn_steps', default=100, type=int, help='number of training steps for one initial projection')
    parser.add_argument('--attempt_cache', default=10000000, type=int, help='maximum number of training edges sampled when an expert attempts a dataset during assignment')
    return parser.parse_args()
args = parse_args()
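
One behavior worth flagging for the `--eval_loss` flag defined above (the same pattern appears in both `params.py` files): argparse applies `type=bool` to the raw command-line string, and `bool()` of any non-empty string is `True`, so `--eval_loss False` still yields `True`. A minimal demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
# Same pattern as --eval_loss in params.py: bool() is applied to the
# raw command-line string, so only the empty string parses as False.
parser.add_argument('--eval_loss', default=True, type=bool)

as_false = parser.parse_args(['--eval_loss', 'False']).eval_loss  # still True
as_empty = parser.parse_args(['--eval_loss', '']).eval_loss       # False
```

A common workaround is `action='store_true'`/`store_false` or a custom string-to-bool converter; as written, the flag is effectively always on unless passed an empty string.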
SYMBOL INDEX (166 symbols across 10 files)

FILE: Utils/TimeLogger.py
  function log (line 6) | def log(msg, save=None, oneline=False):
  function marktime (line 21) | def marktime(marker):

FILE: data_handler.py
  class MultiDataHandler (line 13) | class MultiDataHandler:
    method __init__ (line 14) | def __init__(self, trn_datasets, tst_datasets_group):
    method make_joint_trn_loader (line 36) | def make_joint_trn_loader(self):
    method remake_initial_projections (line 44) | def remake_initial_projections(self):
  class DataHandler (line 49) | class DataHandler:
    method __init__ (line 50) | def __init__(self, data_name):
    method get_data_files (line 57) | def get_data_files(self):
    method load_one_file (line 68) | def load_one_file(self, filename):
    method load_feats (line 75) | def load_feats(self, filename):
    method normalize_adj (line 84) | def normalize_adj(self, mat, log=False):
    method unique_numpy (line 99) | def unique_numpy(self, row, col):
    method make_torch_adj (line 106) | def make_torch_adj(self, mat, unidirectional_for_asym=False):
    method load_data (line 148) | def load_data(self):
    method make_projectors (line 207) | def make_projectors(self):
  class TstData (line 238) | class TstData(data.Dataset):
    method __init__ (line 239) | def __init__(self, coomat, trn_mat):
    method __len__ (line 254) | def __len__(self):
    method __getitem__ (line 257) | def __getitem__(self, idx):
  class TrnData (line 260) | class TrnData(data.Dataset):
    method __init__ (line 261) | def __init__(self, coomat):
    method neg_sampling (line 269) | def neg_sampling(self):
    method __len__ (line 272) | def __len__(self):
    method __getitem__ (line 275) | def __getitem__(self, idx):
  class JointTrnData (line 278) | class JointTrnData(data.Dataset):
    method __init__ (line 279) | def __init__(self, dataset_list):
    method neg_sampling (line 291) | def neg_sampling(self):
    method __len__ (line 295) | def __len__(self):
    method __getitem__ (line 298) | def __getitem__(self, idx):

FILE: main.py
  class Exp (line 14) | class Exp:
    method __init__ (line 15) | def __init__(self, multi_handler):
    method make_print (line 31) | def make_print(self, name, ep, reses, save, data_name=None):
    method run (line 45) | def run(self):
    method print_model_size (line 120) | def print_model_size(self):
    method prepare_model (line 135) | def prepare_model(self):
    method train_epoch (line 140) | def train_epoch(self):
    method make_trn_masks (line 189) | def make_trn_masks(self, numpy_usrs, csr_mat):
    method test_epoch (line 195) | def test_epoch(self, handler, dataset_id):
    method calc_recall_ndcg (line 226) | def calc_recall_ndcg(self, topLocs, tstLocs, batIds):
    method save_history (line 245) | def save_history(self):
    method load_model (line 257) | def load_model(self):

FILE: model.py
  class FeedForwardLayer (line 13) | class FeedForwardLayer(nn.Module):
    method __init__ (line 14) | def __init__(self, in_feat, out_feat, bias=True, act=None):
    method forward (line 28) | def forward(self, embeds):
  class TopoEncoder (line 33) | class TopoEncoder(nn.Module):
    method __init__ (line 34) | def __init__(self):
    method forward (line 39) | def forward(self, adj, embeds, normed=False):
  class MLP (line 55) | class MLP(nn.Module):
    method __init__ (line 56) | def __init__(self):
    method forward (line 62) | def forward(self, embeds):
  class GTLayer (line 67) | class GTLayer(nn.Module):
    method __init__ (line 68) | def __init__(self):
    method _pick_anchors (line 76) | def _pick_anchors(self, embeds):
    method forward (line 81) | def forward(self, embeds):
  class GraphTransformer (line 91) | class GraphTransformer(nn.Module):
    method __init__ (line 92) | def __init__(self):
    method forward (line 96) | def forward(self, embeds):
  class Feat_Projector (line 101) | class Feat_Projector(nn.Module):
    method __init__ (line 102) | def __init__(self, feats):
    method svd_proj (line 116) | def svd_proj(self, feats):
    method uniform_proj (line 127) | def uniform_proj(self, feats):
    method random_proj (line 131) | def random_proj(self, feats):
    method forward (line 135) | def forward(self):
  class Adj_Projector (line 138) | class Adj_Projector(nn.Module):
    method __init__ (line 139) | def __init__(self, adj):
    method svd_proj (line 146) | def svd_proj(self, adj):
    method forward (line 164) | def forward(self):
  class Expert (line 167) | class Expert(nn.Module):
    method __init__ (line 168) | def __init__(self):
    method forward (line 178) | def forward(self, projectors, pck_nodes=None):
    method pred_norm (line 185) | def pred_norm(self, pos_preds, neg_preds):
    method cal_loss (line 194) | def cal_loss(self, batch_data, projectors):
    method pred_for_test (line 224) | def pred_for_test(self, batch_data, cand_size, projectors, rerun_embed...
    method attempt (line 251) | def attempt(self, topo_embeds, dataset):
  class AnyGraph (line 277) | class AnyGraph(nn.Module):
    method __init__ (line 278) | def __init__(self):
    method assign_experts (line 283) | def assign_experts(self, handlers, reca=True, log_assignment=False):
    method summon (line 314) | def summon(self, dataset_id):
    method summon_opt (line 317) | def summon_opt(self, dataset_id):

FILE: node_classification/Utils/TimeLogger.py
  function log (line 6) | def log(msg, save=None, oneline=False):
  function marktime (line 21) | def marktime(marker):

FILE: node_classification/data_handler.py
  class MultiDataHandler (line 13) | class MultiDataHandler:
    method __init__ (line 14) | def __init__(self, trn_datasets, tst_datasets_group):
    method make_joint_trn_loader (line 36) | def make_joint_trn_loader(self):
    method remake_initial_projections (line 44) | def remake_initial_projections(self):
  class DataHandler (line 49) | class DataHandler:
    method __init__ (line 50) | def __init__(self, data_name):
    method get_data_files (line 57) | def get_data_files(self):
    method load_one_file (line 69) | def load_one_file(self, filename):
    method load_feats (line 76) | def load_feats(self, filename):
    method normalize_adj (line 85) | def normalize_adj(self, mat, log=False):
    method unique_numpy (line 100) | def unique_numpy(self, row, col):
    method make_torch_adj (line 107) | def make_torch_adj(self, mat, unidirectional_for_asym=False):
    method load_data (line 149) | def load_data(self):
    method make_projectors (line 196) | def make_projectors(self):
  class TrnData (line 227) | class TrnData(data.Dataset):
    method __init__ (line 228) | def __init__(self, coomat):
    method neg_sampling (line 238) | def neg_sampling(self):
    method __len__ (line 250) | def __len__(self):
    method __getitem__ (line 253) | def __getitem__(self, idx):
  class NodeTstData (line 256) | class NodeTstData(data.Dataset):
    method __init__ (line 257) | def __init__(self, tst_mat):
    method __len__ (line 261) | def __len__(self):
    method __getitem__ (line 264) | def __getitem__(self, idx):
  class JointTrnData (line 267) | class JointTrnData(data.Dataset):
    method __init__ (line 268) | def __init__(self, dataset_list):
    method neg_sampling (line 280) | def neg_sampling(self):
    method __len__ (line 284) | def __len__(self):
    method __getitem__ (line 287) | def __getitem__(self, idx):

FILE: node_classification/main.py
  class Exp (line 15) | class Exp:
    method __init__ (line 16) | def __init__(self, multi_handler):
    method make_print (line 32) | def make_print(self, name, ep, reses, save, data_name=None):
    method run (line 46) | def run(self):
    method print_model_size (line 124) | def print_model_size(self):
    method prepare_model (line 139) | def prepare_model(self):
    method train_epoch (line 144) | def train_epoch(self):
    method make_trn_masks (line 196) | def make_trn_masks(self, numpy_usrs, csr_mat):
    method test_loss_epoch (line 202) | def test_loss_epoch(self, handler, dataset_id):
    method test_epoch (line 234) | def test_epoch(self, handler, dataset_id):
    method calc_recall_ndcg (line 263) | def calc_recall_ndcg(self, topLocs, tstLocs, batIds):
    method save_history (line 282) | def save_history(self):
    method load_model (line 294) | def load_model(self):

FILE: node_classification/model.py
  class FeedForwardLayer (line 14) | class FeedForwardLayer(nn.Module):
    method __init__ (line 15) | def __init__(self, in_feat, out_feat, bias=True, act=None):
    method forward (line 29) | def forward(self, embeds):
  class TopoEncoder (line 34) | class TopoEncoder(nn.Module):
    method __init__ (line 35) | def __init__(self):
    method forward (line 40) | def forward(self, adj, embeds, normed=False):
  class MLP (line 56) | class MLP(nn.Module):
    method __init__ (line 57) | def __init__(self):
    method forward (line 63) | def forward(self, embeds):
  class GTLayer (line 68) | class GTLayer(nn.Module):
    method __init__ (line 69) | def __init__(self):
    method _pick_anchors (line 77) | def _pick_anchors(self, embeds):
    method forward (line 82) | def forward(self, embeds):
  class GraphTransformer (line 92) | class GraphTransformer(nn.Module):
    method __init__ (line 93) | def __init__(self):
    method forward (line 97) | def forward(self, embeds):
  class Feat_Projector (line 102) | class Feat_Projector(nn.Module):
    method __init__ (line 103) | def __init__(self, feats):
    method svd_proj (line 117) | def svd_proj(self, feats):
    method uniform_proj (line 128) | def uniform_proj(self, feats):
    method random_proj (line 132) | def random_proj(self, feats):
    method forward (line 136) | def forward(self):
  class Adj_Projector (line 139) | class Adj_Projector(nn.Module):
    method __init__ (line 140) | def __init__(self, adj):
    method svd_proj (line 147) | def svd_proj(self, adj):
    method forward (line 165) | def forward(self):
  class Expert (line 168) | class Expert(nn.Module):
    method __init__ (line 169) | def __init__(self):
    method forward (line 179) | def forward(self, projectors, pck_nodes=None):
    method pred_norm (line 186) | def pred_norm(self, pos_preds, neg_preds):
    method cal_loss (line 195) | def cal_loss(self, batch_data, projectors):
    method pred_for_test (line 225) | def pred_for_test(self, batch_data, cand_size, projectors, rerun_embed...
    method pred_for_node_test (line 252) | def pred_for_node_test(self, nodes, cand_size, feats, rerun_embed=True):
    method attempt (line 262) | def attempt(self, topo_embeds, dataset):
  class AnyGraph (line 288) | class AnyGraph(nn.Module):
    method __init__ (line 289) | def __init__(self):
    method one_graph_one_expert (line 294) | def one_graph_one_expert(self, handlers, log_assignment=False):
    method store_history_assignment (line 301) | def store_history_assignment(self, assignment):
    method assign_experts (line 304) | def assign_experts(self, handlers, reca=True, log_assignment=False):
    method summon (line 345) | def summon(self, dataset_id):
    method summon_opt (line 348) | def summon_opt(self, dataset_id):

FILE: node_classification/params.py
  function parse_args (line 3) | def parse_args():

FILE: params.py
  function parse_args (line 3) | def parse_args():


Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
