[
  {
    "path": ".gitignore",
    "content": "extract_code_structure.py\n.DS_Store"
  },
  {
    "path": "Models/README.md",
    "content": "Download the pre-trained AnyGraph models from <a href='https://hkuhk-my.sharepoint.com/:u:/g/personal/lhaoxia_hku_hk/Efmm5TJm0B5EnmYzTqg8GWEB1loKzeIR5tcr3hPIOJDXXA?e=2wMgZC'>this link</a>."
  },
  {
    "path": "README.md",
    "content": "<h1 align='center'>AnyGraph: Graph Foundation Model in the Wild</h1>\n\n<div align='center'>\n<a href='https://arxiv.org/pdf/2408.10700'><img src='https://img.shields.io/badge/Paper-PDF-green'></a>\n<!-- <a href=''><img src='https://img.shields.io/badge/公众号-blue' /></a> -->\n<!-- <a href=''><img src='https://img.shields.io/badge/CSDN-orange' /></a> -->\n<img src=\"https://badges.pufler.dev/visits/hkuds/anygraph?style=flat-square&logo=github\">\n<img src='https://img.shields.io/github/stars/hkuds/anygraph?color=green&style=social' />\n\n<a href='https://akaxlh.github.io/'>Lianghao Xia</a> and <a href='https://sites.google.com/view/chaoh/group-join-us'>Chao Huang</a>\n\n**Introducing AnyGraph, a graph foundation model designed for zero-shot predictions across domains.**\n\n<img src='imgs/article cover.png' />\n\n</div>\n\n**Objectives of AnyGraph:**\n\n- **Structure Heterogeneity**: Addressing distribution shift in graph structural information.\n- **Feature Heterogeneity**: Handling diverse feature representation spaces across graph datasets.\n- **Fast Adaptation**: Efficiently adapting the model to new graph domains.\n- **Scaling Law Emergence**: Performance scales with the amount of data and model parameters.\n\n<br>\n\n**Key Features of AnyGraph:**\n\n- **Graph Mixture-of-Experts (MoE)**: Effectively addresses cross-domain heterogeneity using an array of expert models.\n- **Lightweight Graph Expert Routing Mechanism**: Enables swift adaptation to new datasets and domains.\n- **Adaptive and Efficient Graph Experts**: Custom-designed to handle graphs with a wide range of structural patterns and feature spaces.\n- **Extensively Trained and Tested**: Exhibits strong generalizability over 38 diverse graph datasets, showcasing scaling laws and emergent capabilities.\n\n<img src='imgs/framework_final.jpeg' />\n\n## News\n- [x] 2024/09/18 Fix minor bugs on loading and saving path, class names, etc.\n- [x] 2024/09/11 Datasets are available now!\n- [x] 
2024/09/11 Pre-trained models are available on Hugging Face now!\n- [x] 2024/08/20 Paper is released.\n- [x] 2024/08/20 Code is released.\n\n## Environment Setup\nDownload the data files at <a href='https://huggingface.co/datasets/hkuds/AnyGraph_datasets'>this link</a>, and fill in your own directories for data storage in the `get_data_files(self)` method of the `DataHandler` class in `data_handler.py`.\n\nDownload the pre-trained AnyGraph models from <a href='https://huggingface.co/hkuds/AnyGraph/'>Hugging Face</a> or <a href='https://hkuhk-my.sharepoint.com/:u:/g/personal/lhaoxia_hku_hk/Efmm5TJm0B5EnmYzTqg8GWEB1loKzeIR5tcr3hPIOJDXXA?e=2wMgZC'>OneDrive</a>, and put them into `Models/`.\n\n**Packages**: Our experiments were conducted with the following package versions:\n* python==3.10.13\n* torch==1.13.0\n* numpy==1.23.4\n* scipy==1.9.3\n\n**Device Requirements**: Training and testing AnyGraph requires only one GPU with 24GB of memory (e.g. 3090, 4090). Larger input graphs may require devices with more memory.\n\n## Code Structure\nHere is a brief overview of the code structure. 
The explanation for each file and directory is marked with `##...##`.\n```\n./\n├── .gitignore\n├── README.md\n├── data_handler.py ## load and process data\n├── model.py ## implementation of the model\n├── params.py ## hyperparameters\n├── main.py ## main file for pre-training and link prediction\n├── History/ ## training and testing logs\n│   ├── pretrain_link1.his\n│   └── pretrain_link2.his\n├── Models/ ## pre-trained models\n│   └── README.md\n├── Utils/ ## utility functions\n│   └── TimeLogger.py\n├── imgs/ ## images used in the readme\n│   ├── ablation.png\n│   ├── article cover.png\n│   ├── datasets.png\n│   ├── framework_final.jpeg\n│   ├── framework_final.pdf\n│   ├── framework_final.png\n│   ├── overall_performance1.png\n│   ├── overall_performance2.png\n│   ├── routing.png\n│   ├── scaling_law.png\n│   ├── training_time.png\n│   └── tuning_steps.png\n└── node_classification/ ## test code for node classification\n    ├── data_handler.py\n    ├── model.py\n    ├── params.py\n    ├── main.py\n    └── Utils/\n        └── TimeLogger.py\n```\n\n## Usage\nTo reproduce the test performance reported in the paper, run the following commands:\n```\n# Test on the Link2 and Link1 data, respectively\npython main.py --load pretrain_link1 --epoch 0 --dataset link2\npython main.py --load pretrain_link2 --epoch 0 --dataset link1\n\n# Test on the Ecommerce datasets in the Link2 and Link1 groups, respectively.\n# Tests on the Academic and Others datasets are conducted similarly.\npython main.py --load pretrain_link1 --epoch 0 --dataset ecommerce_in_link2\npython main.py --load pretrain_link2 --epoch 0 --dataset ecommerce_in_link1\n\n# Test the performance on the node classification datasets\ncd ./node_classification\npython main.py --load pretrain_link2 --epoch 0 --dataset node\n```\n\nTo re-train the two models yourself, run:\n```\npython main.py 
--dataset link2+link1 --save pretrain_link2\npython main.py --dataset link1+link2 --save pretrain_link1\n```\n\n## Datasets\n\n<img src='imgs/datasets.png' />\n\nThe statistics for the experimental datasets are presented in the table above. We categorize them into the distinct groups below. Note that Link1 and Link2 include datasets from different sources, and their datasets do not share feature spaces. This separation ensures a robust evaluation of true zero-shot performance in graph prediction tasks.\n\n| Group | Included Datasets |\n| ----- | ----- |\n| Link1 | Products-tech, Yelp2018, Yelp-textfeat, Products-home, Steam-text, Amazon-text, Amazon-book, Citation-2019, Citation-20Century, Pubmed-link, Citeseer, OGB-PPA, P2P-Gnutella06, Soc-Epinions1, Email-Enron |\n| Link2 | Photo, Goodreads, Fitness, Movielens-1M, Movielens10M, Gowalla, Arxiv, Arxiv-t, Cora, CS, OGB-Collab, Proteins-0, Proteins-1, Proteins-2, Proteins-3, OGB-DDI, Web-Stanford, RoadNet-PA |\n| Ecommerce | Products-tech, Yelp2018, Yelp-textfeat, Products-home, Steam-text, Amazon-text, Amazon-book, Photo, Goodreads, Fitness, Movielens-1M, Movielens10M, Gowalla |\n| Academic | Citation-2019, Citation-20Century, Pubmed-link, Citeseer, OGB-PPA, Arxiv, Arxiv-t, Cora, CS, OGB-Collab |\n| Others | P2P-Gnutella06, Soc-Epinions1, Email-Enron, Proteins-0, Proteins-1, Proteins-2, Proteins-3, OGB-DDI, Web-Stanford, RoadNet-PA |\n| Node | Cora, Arxiv, Pubmed, Home, Tech |\n\n\n## Experiments\n\n### Model Pre-Training Curves\nWe present the training logs with respect to epochs below. 
Each figure contains two curves, corresponding to two repeated pre-training runs.\n\n- pretrain_link1\n\n<img src='imgs/link1_loss_curve.png' width=32%/>&nbsp;\n<img src='imgs/link1_fullshot_ndcg_curve.png' width=32%/>&nbsp;\n<img src='imgs/link1_zeroshot_ndcg_curve.png' width=32%/>\n\n- pretrain_link2\n\n<img src='imgs/link2_loss_curve.png' width=32%/>&nbsp;\n<img src='imgs/link2_fullshot_ndcg_curve.png' width=32%/>&nbsp;\n<img src='imgs/link2_zeroshot_ndcg_curve.png' width=32%/>\n\n### Overall Performance Comparison\n\n- Comparison with few-shot end-to-end models and pre-training-and-fine-tuning methods.\n![](imgs/overall_performance1.png)\n\n- Comparison with zero-shot graph foundation models.\n<img src='imgs/overall_performance2.png' width=60%/>\n\n### Scaling Law of AnyGraph\n\nWe explore the scaling law of AnyGraph by evaluating 1) model performance vs. the number of model parameters, and 2) model performance vs. the number of training samples.\n\nThe figure below shows the evaluation results on\n- all datasets across domains (a)\n- academic datasets (b)\n- ecommerce datasets (c)\n- other datasets (d)\n\nIn each subfigure, we show\n- zero-shot performance on unseen datasets w.r.t. the number of model parameters (left)\n- full-shot performance on training datasets w.r.t. the number of model parameters (middle)\n- zero-shot performance w.r.t. the amount of training data (right)\n\n![](imgs/scaling_law.png)\n\nThe results yield the following key observations: (see Sec. 
4.3 for details)\n- The generalizability of AnyGraph follows the scaling law.\n- AnyGraph exhibits emergent abilities.\n- Insufficient training data may introduce bias.\n\n### Ablation Study\n\nThe ablation study investigates the impact of the following modules:\n- The overall MoE architecture\n- Frequency regularization in the expert routing mechanism\n- Graph augmentation in the learning process\n- The utilization of (heterogeneous) node features from different datasets\n\n<img src='imgs/ablation.png' width=60% />\n\n\n### Expert Routing Mechanism\nWe visualize the competence scores between datasets and experts, given by the routing algorithm of AnyGraph.\n\nThe scores below reveal the underlying relatedness between different datasets, demonstrating the effectiveness of the routing mechanism (see Sec. 4.5 for details).\n\n<img src='imgs/routing.png' width=60% />\n\n\n### Fast Adaptation of AnyGraph\n\nWe study the fast-adaptation abilities of AnyGraph from two aspects:\n- When fine-tuned on unseen datasets, AnyGraph achieves better performance with fewer training steps. (Fig. 6 below)\n- The training time of AnyGraph is comparable to that of other methods. (Table 3 below)\n\n<img src='imgs/tuning_steps.png' width=60% />\n\n<img src='imgs/training_time.png' width=60% />\n\n## Citation\n\nIf you find our work useful, please consider citing our paper:\n```\n@article{xia2024anygraph,\n  title={AnyGraph: Graph Foundation Model in the Wild},\n  author={Xia, Lianghao and Huang, Chao},\n  journal={arXiv preprint arXiv:2408.10700},\n  year={2024}\n}\n```\n"
  },
  {
    "path": "Utils/TimeLogger.py",
"content": "import datetime\n\nlogmsg = ''\ntimemark = dict()\nsaveDefault = False\ndef log(msg, save=None, oneline=False):\n\tglobal logmsg\n\tglobal saveDefault\n\ttime = datetime.datetime.now()\n\ttem = '%s: %s' % (time, msg)\n\tif save is not None:\n\t\tif save:\n\t\t\tlogmsg += tem + '\\n'\n\telif saveDefault:\n\t\tlogmsg += tem + '\\n'\n\tif oneline:\n\t\tprint(tem, end='\\r')\n\telse:\n\t\tprint(tem)\n\ndef marktime(marker):\n\tglobal timemark\n\ttimemark[marker] = datetime.datetime.now()\n\n\nif __name__ == '__main__':\n\tlog('')"
  },
  {
    "path": "data_handler.py",
    "content": "import pickle\nimport numpy as np\nfrom scipy.sparse import csr_matrix, coo_matrix, dok_matrix\nfrom params import args\nimport scipy.sparse as sp\nfrom Utils.TimeLogger import log\nimport torch as t\nimport torch.utils.data as data\nimport torch_geometric.transforms as T\nfrom model import Feat_Projector, Adj_Projector, TopoEncoder\nimport os\n\nclass MultiDataHandler:\n    def __init__(self, trn_datasets, tst_datasets_group):\n        all_datasets = trn_datasets\n        all_tst_datasets = []\n        for tst_datasets in tst_datasets_group:\n            all_datasets = all_datasets + tst_datasets\n            all_tst_datasets += tst_datasets\n        all_datasets = list(set(all_datasets))\n        all_datasets.sort()\n        self.trn_handlers = []\n        self.tst_handlers_group = [list() for i in range(len(tst_datasets_group))]\n        for data_name in all_datasets:\n            trn_flag = data_name in trn_datasets\n            tst_flag = data_name in all_tst_datasets\n            handler = DataHandler(data_name)\n            if trn_flag:\n                self.trn_handlers.append(handler)\n            if tst_flag:\n                for i in range(len(tst_datasets_group)):\n                    if data_name in tst_datasets_group[i]:\n                        self.tst_handlers_group[i].append(handler)\n        self.make_joint_trn_loader()\n    \n    def make_joint_trn_loader(self):\n        loader_datasets = []\n        for trn_handler in self.trn_handlers:\n            tem_dataset = trn_handler.trn_loader.dataset\n            loader_datasets.append(tem_dataset)\n        joint_dataset = JointTrnData(loader_datasets)\n        self.joint_trn_loader = data.DataLoader(joint_dataset, batch_size=1, shuffle=True, num_workers=0)\n    \n    def remake_initial_projections(self):\n        for i in range(len(self.trn_handlers)):\n            trn_handler = self.trn_handlers[i]\n            trn_handler.make_projectors()\n\nclass DataHandler:\n    def 
__init__(self, data_name):\n        self.data_name = data_name\n        self.get_data_files()\n        log(f'Loading dataset {data_name}')\n        self.topo_encoder = TopoEncoder()\n        self.load_data()\n    \n    def get_data_files(self):\n        predir = f'/home/user_name/data/zero-shot datasets/{self.data_name}/'\n        if os.path.exists(predir + 'feats.pkl'):\n            self.feat_file = predir + 'feats.pkl'\n        else:\n            self.feat_file = None\n        self.trnfile = predir + 'trn_mat.pkl'\n        self.tstfile = predir + 'tst_mat.pkl'\n        self.fewshotfile = predir + 'partial_mat_{shot}.pkl'.format(shot=args.ratio_fewshot)\n        self.valfile = predir + 'val_mat.pkl'\n\n    def load_one_file(self, filename):\n        with open(filename, 'rb') as fs:\n            ret = (pickle.load(fs) != 0).astype(np.float32)\n        if type(ret) != coo_matrix:\n            ret = sp.coo_matrix(ret)\n        return ret\n    \n    def load_feats(self, filename):\n        try:\n            with open(filename, 'rb') as fs:\n                feats = pickle.load(fs)\n        except Exception as e:\n            print(filename + str(e))\n            exit()\n        return feats\n\n    def normalize_adj(self, mat, log=False):\n        degree = np.array(mat.sum(axis=-1))\n        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])\n        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0\n        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)\n        if mat.shape[0] == mat.shape[1]:\n            return mat.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo()\n        else:\n            tem = d_inv_sqrt_mat.dot(mat)\n            col_degree = np.array(mat.sum(axis=0))\n            d_inv_sqrt = np.reshape(np.power(col_degree, -0.5), [-1])\n            d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0\n            d_inv_sqrt_mat = sp.diags(d_inv_sqrt)\n            return tem.dot(d_inv_sqrt_mat).tocoo()\n    \n    def unique_numpy(self, row, col):\n        hash_vals = row * 
args.node_num + col\n        hash_vals = np.unique(hash_vals).astype(np.int64)\n        col = hash_vals % args.node_num\n        row = (hash_vals - col).astype(np.int64) // args.node_num\n        return row, col\n\n    def make_torch_adj(self, mat, unidirectional_for_asym=False):\n        if mat.shape[0] == mat.shape[1]:\n            _row = mat.row\n            _col = mat.col\n            row = np.concatenate([_row, _col]).astype(np.int64)\n            col = np.concatenate([_col, _row]).astype(np.int64)\n            row, col = self.unique_numpy(row, col)\n            data = np.ones_like(row)\n            mat = coo_matrix((data, (row, col)), mat.shape)\n            if args.selfloop == 1:\n                mat = (mat + sp.eye(mat.shape[0])) * 1.0\n            normed_asym_mat = self.normalize_adj(mat)\n            row = t.from_numpy(normed_asym_mat.row).long()\n            col = t.from_numpy(normed_asym_mat.col).long()\n            idxs = t.stack([row, col], dim=0)\n            vals = t.from_numpy(normed_asym_mat.data).float()\n            shape = t.Size(normed_asym_mat.shape)\n            asym_adj = t.sparse.FloatTensor(idxs, vals, shape)\n            return asym_adj\n        elif unidirectional_for_asym:\n            mat = (mat != 0) * 1.0\n            mat = self.normalize_adj(mat, log=True)\n            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))\n            vals = t.from_numpy(mat.data.astype(np.float32))\n            shape = t.Size(mat.shape)\n            return t.sparse.FloatTensor(idxs, vals, shape)\n        else:\n            # make ui adj\n            a = sp.csr_matrix((args.user_num, args.user_num))\n            b = sp.csr_matrix((args.item_num, args.item_num))\n            mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])\n            mat = (mat != 0) * 1.0\n            if args.selfloop == 1:\n                mat = (mat + sp.eye(mat.shape[0])) * 1.0\n            mat = self.normalize_adj(mat)\n\n            # 
make cuda tensor\n            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))\n            vals = t.from_numpy(mat.data.astype(np.float32))\n            shape = t.Size(mat.shape)\n            return t.sparse.FloatTensor(idxs, vals, shape)\n\n    def load_data(self):\n        tst_mat = self.load_one_file(self.tstfile)\n        val_mat = self.load_one_file(self.valfile)\n        trn_mat = self.load_one_file(self.trnfile)\n        fewshot_mat = self.load_one_file(self.fewshotfile)\n        if self.feat_file is not None:\n            self.feats = t.from_numpy(self.load_feats(self.feat_file)).float()\n            self.feats = self.feats\n            args.featdim = self.feats.shape[1]\n        else:\n            self.feats = None\n            args.featdim = args.latdim\n\n        if trn_mat.shape[0] != trn_mat.shape[1]:\n            args.user_num, args.item_num = trn_mat.shape\n            args.node_num = args.user_num + args.item_num\n            print('Dataset: {data_name}, User num: {user_num}, Item num: {item_num}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, user_num=args.user_num, item_num=args.item_num, node_num=args.node_num, edge_num=trn_mat.nnz))\n        else:\n            args.node_num = trn_mat.shape[0]\n            print('Dataset: {data_name}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, node_num=args.node_num, edge_num=trn_mat.nnz+val_mat.nnz+tst_mat.nnz))\n        if args.tst_mode == 'tst':\n            tst_data = TstData(tst_mat, trn_mat)\n            self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)\n            self.tst_input_adj = self.make_torch_adj(trn_mat)\n        elif args.tst_mode == 'val':\n            tst_data = TstData(val_mat, trn_mat)\n            self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)\n            self.tst_input_adj = 
self.make_torch_adj(fewshot_mat)\n        else:\n            raise Exception('Specify proper test mode')\n\n        if args.trn_mode == 'fewshot':\n            self.trn_mat = fewshot_mat\n            trn_data = TrnData(self.trn_mat)\n            self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0)\n            self.trn_input_adj = self.make_torch_adj(fewshot_mat)\n            if args.tst_mode == 'val':\n                self.trn_input_adj = self.tst_input_adj\n            else:\n                self.trn_input_adj = self.make_torch_adj(fewshot_mat)\n        elif args.trn_mode == 'train-all':\n            self.trn_mat = trn_mat\n            trn_data = TrnData(self.trn_mat)\n            self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0)\n            if args.tst_mode == 'tst':\n                self.trn_input_adj = self.tst_input_adj\n            else:\n                self.trn_input_adj = self.make_torch_adj(trn_mat)\n        else:\n            raise Exception('Specify proper train mode')   \n\n        if self.trn_mat.shape[0] == self.trn_mat.shape[1]:\n            self.asym_adj = self.trn_input_adj\n        else:\n            self.asym_adj = self.make_torch_adj(self.trn_mat, unidirectional_for_asym=True)\n        self.make_projectors()\n        self.reproj_steps = max(len(self.trn_loader.dataset) // (10 * args.batch), args.proj_trn_steps)\n        self.ratio_500_all = 500 / len(self.trn_loader)\n    \n    def make_projectors(self):\n        with t.no_grad():\n            projectors = []\n            if args.proj_method == 'adj_svd' or args.proj_method == 'both':\n                tem = self.asym_adj.to(args.devices[0])\n                projectors = [Adj_Projector(tem)]\n            if self.feats is not None and args.proj_method != 'adj_svd':\n                tem = self.feats.to(args.devices[0])\n                projectors.append(Feat_Projector(tem))\n            assert args.tst_mode 
== 'tst' and args.trn_mode == 'train-all' or args.tst_mode == 'val' and args.trn_mode == 'fewshot'\n            feats = projectors[0]()\n            if len(projectors) == 2:\n                feats2 = projectors[1]()\n                feats = feats + feats2\n\n            try:\n                self.projectors = self.topo_encoder(self.trn_input_adj.to(args.devices[0]), feats.to(args.devices[0])).detach().cpu()\n            except Exception:\n                print(f'{self.data_name} memory overflow')\n                mean, std = feats.mean(dim=-1, keepdim=True), feats.std(dim=-1, keepdim=True)\n                tem_adj = self.trn_input_adj.to(args.devices[0])\n                mem_cache = 256\n                projectors_list = []\n                for i in range(feats.shape[1] // mem_cache):\n                    st, ed = i * mem_cache, (i + 1) * mem_cache\n                    tem_feats = (feats[:, st:ed] - mean) / (std + 1e-8)\n                    tem_feats = self.topo_encoder(tem_adj, tem_feats.to(args.devices[0]), normed=True).detach().cpu()\n                    projectors_list.append(tem_feats)\n                self.projectors = t.concat(projectors_list, dim=-1)\n            t.cuda.empty_cache()\n\nclass TstData(data.Dataset):\n    def __init__(self, coomat, trn_mat):\n        self.csrmat = (trn_mat.tocsr() != 0) * 1.0\n        tstLocs = [None] * coomat.shape[0]\n        tst_nodes = set()\n        for i in range(len(coomat.data)):\n            row = coomat.row[i]\n            col = coomat.col[i]\n            if tstLocs[row] is None:\n                tstLocs[row] = list()\n            tstLocs[row].append(col)\n            tst_nodes.add(row)\n        tst_nodes = np.array(list(tst_nodes))\n        self.tst_nodes = tst_nodes\n        self.tstLocs = tstLocs\n\n    def __len__(self):\n        return len(self.tst_nodes)\n\n    def __getitem__(self, idx):\n        return self.tst_nodes[idx]\n\nclass TrnData(data.Dataset):\n    def __init__(self, coomat):\n        self.ancs, 
self.poss = coomat.row, coomat.col\n        self.negs = np.zeros(len(self.ancs)).astype(np.int32)\n        self.cand_num = coomat.shape[1]\n        self.neg_shift = 0 if coomat.shape[0] == coomat.shape[1] else coomat.shape[0]\n        self.poss = coomat.col + self.neg_shift\n        self.neg_sampling()\n    \n    def neg_sampling(self):\n        self.negs = np.random.randint(self.cand_num + self.neg_shift, size=self.poss.shape[0])\n\n    def __len__(self):\n        return len(self.ancs)\n    \n    def __getitem__(self, idx):\n        return self.ancs[idx], self.poss[idx] , self.negs[idx]\n\nclass JointTrnData(data.Dataset):\n    def __init__(self, dataset_list):\n        self.batch_dataset_ids = []\n        self.batch_st_ed_list = []\n        self.dataset_list = dataset_list\n        for dataset_id, dataset in enumerate(dataset_list):\n            samp_num = len(dataset) // args.batch + (1 if len(dataset) % args.batch != 0 else 0)\n            for j in range(samp_num):\n                self.batch_dataset_ids.append(dataset_id)\n                st = j * args.batch\n                ed = min((j + 1) * args.batch, len(dataset))\n                self.batch_st_ed_list.append((st, ed))\n    \n    def neg_sampling(self):\n        for dataset in self.dataset_list:\n            dataset.neg_sampling()\n\n    def __len__(self):\n        return len(self.batch_dataset_ids)\n    \n    def __getitem__(self, idx):\n        st, ed = self.batch_st_ed_list[idx]\n        dataset_id = self.batch_dataset_ids[idx]\n        return *self.dataset_list[dataset_id][st: ed], dataset_id"
  },
  {
    "path": "main.py",
    "content": "import torch as t\nfrom torch import nn\nimport Utils.TimeLogger as logger\nfrom Utils.TimeLogger import log\nfrom params import args\nfrom model import Expert, Feat_Projector, Adj_Projector, AnyGraph\nfrom data_handler import MultiDataHandler, DataHandler\nimport numpy as np\nimport pickle\nimport os\nimport setproctitle\nimport time\n\nclass Exp:\n    def __init__(self, multi_handler):\n        self.multi_handler = multi_handler\n        print(list(map(lambda x: x.data_name, multi_handler.trn_handlers)))\n        for group_id, tst_handlers in enumerate(multi_handler.tst_handlers_group):\n            print(f'Test group {group_id}', list(map(lambda x: x.data_name, tst_handlers)))\n        self.metrics = dict()\n        trn_mets = ['Loss', 'preLoss']\n        tst_mets = ['Recall', 'NDCG', 'Loss', 'preLoss']\n        mets = trn_mets + tst_mets\n        for met in mets:\n            if met in trn_mets:\n                self.metrics['Train' + met] = list()\n            if met in tst_mets:\n                for i in range(len(self.multi_handler.tst_handlers_group)):\n                    self.metrics['Test' + str(i) + met] = list()\n        \n    def make_print(self, name, ep, reses, save, data_name=None):\n        if data_name is None:\n            ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)\n        else:\n            ret = 'Epoch %d/%d, %s %s: ' % (ep, args.epoch, data_name, name)\n        for metric in reses:\n            val = reses[metric]\n            ret += '%s = %.4f, ' % (metric, val)\n            tem = name + metric if data_name is None else name + data_name + metric\n            if save and tem in self.metrics:\n                self.metrics[tem].append(val)\n        ret = ret[:-2] + '      '\n        return ret\n    \n    def run(self):\n        self.prepare_model()\n        log('Model Prepared')\n        stloc = 0\n        if args.load_model != None:\n            self.load_model()\n            stloc = len(self.metrics['TrainLoss']) * 
args.tst_epoch - (args.tst_epoch - 1)\n        best_ndcg, best_ep = 0, -1\n        for ep in range(stloc, args.epoch):\n            tst_flag = (ep % args.tst_epoch == 0)\n            start_time = time.time()\n            self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=True)\n            reses = self.train_epoch()\n            log(self.make_print('Train', ep, reses, tst_flag))\n            self.multi_handler.remake_initial_projections()\n            end_time = time.time()\n            print(f'NOTICE: {end_time-start_time}')\n            if tst_flag:\n                for handler_group_id in range(len(self.multi_handler.tst_handlers_group)):\n                    tst_handlers = self.multi_handler.tst_handlers_group[handler_group_id]\n                    self.model.assign_experts(tst_handlers, reca=False, log_assignment=True)\n                    recall, ndcg, tstnum = 0, 0, 0\n                    for i, handler in enumerate(tst_handlers):\n                        reses = self.test_epoch(handler, i)\n                        # log(self.make_print(f'{handler.data_name}', ep, reses, False))\n                        recall += reses['Recall'] * reses['tstNum']\n                        ndcg += reses['NDCG'] * reses['tstNum']\n                        tstnum += reses['tstNum']\n                    reses = {'Recall': recall / tstnum, 'NDCG': ndcg / tstnum}\n                    log(self.make_print('Test'+str(handler_group_id), ep, reses, tst_flag))\n\n                    if reses['NDCG'] > best_ndcg:\n                        best_ndcg = reses['NDCG']\n                        best_ep = ep\n                self.save_history()\n            print()\n\n        for test_group_id in range(len(self.multi_handler.tst_handlers_group)):\n            repeat_times = 5\n            overall_recall, overall_ndcg = np.zeros(repeat_times), np.zeros(repeat_times)\n            overall_tstnum = 0\n            tst_handlers = 
self.multi_handler.tst_handlers_group[test_group_id]\n            for i, handler in enumerate(tst_handlers):\n                for topk in [args.topk]:\n                    args.topk = topk\n                    mets = dict()\n                    for _ in range(repeat_times):\n                        handler.make_projectors()\n                        self.model.assign_experts([handler], reca=False, log_assignment=False)\n                        reses = self.test_epoch(handler, 0)\n                        for met in reses:\n                            if met not in mets:\n                                mets[met] = []\n                            mets[met].append(reses[met])\n                    tstnum = reses['tstNum']\n                    tot_reses = dict()\n                    for met in reses:\n                        tem_arr = np.array(mets[met])\n                        tot_reses[met + '_std'] = tem_arr.std()\n                        tot_reses[met + '_mean'] = tem_arr.mean()\n                    if topk == args.topk:\n                        overall_recall += np.array(mets['Recall']) * tstnum\n                        overall_ndcg += np.array(mets['NDCG']) * tstnum\n                        overall_tstnum += tstnum\n                    log(self.make_print(f'Test Top-{topk}', args.epoch, tot_reses, False, handler.data_name))\n            overall_recall /= overall_tstnum\n            overall_ndcg /= overall_tstnum\n            overall_res = dict()\n            overall_res['Recall_mean'] = overall_recall.mean()\n            overall_res['Recall_std'] = overall_recall.std()\n            overall_res['NDCG_mean'] = overall_ndcg.mean()\n            overall_res['NDCG_std'] = overall_ndcg.std()\n            log(self.make_print('Overall Test', args.epoch, overall_res, False))\n        self.save_history()\n\n    def print_model_size(self):\n        total_params = 0\n        trainable_params = 0\n        non_trainable_params = 0\n        for param in self.model.parameters():\n 
           tem = np.prod(param.size())\n            total_params += tem\n            if param.requires_grad:\n                trainable_params += tem\n            else:\n                non_trainable_params += tem\n        print(f'Total params: {total_params/1e6}')\n        print(f'Trainable params: {trainable_params/1e6}')\n        print(f'Non-trainable params: {non_trainable_params/1e6}')\n\n    def prepare_model(self):\n        self.model = AnyGraph()\n        t.cuda.empty_cache()\n        self.print_model_size()\n\n    def train_epoch(self):\n        self.model.train()\n        trn_loader = self.multi_handler.joint_trn_loader\n        trn_loader.dataset.neg_sampling()\n        ep_loss, ep_preloss, ep_regloss = 0, 0, 0\n        steps = len(trn_loader)\n        tot_samp_num = 0\n        counter = [0] * len(self.multi_handler.trn_handlers)\n        reassign_steps = sum(list(map(lambda x: x.reproj_steps, self.multi_handler.trn_handlers)))\n        for i, batch_data in enumerate(trn_loader):\n            if args.epoch_max_step > 0 and i >= args.epoch_max_step:\n                break\n            ancs, poss, negs, dataset_id = batch_data\n            ancs = ancs[0].long()\n            poss = poss[0].long()\n            negs = negs[0].long()\n            dataset_id = dataset_id[0].long()\n            tem_bar = self.multi_handler.trn_handlers[dataset_id].ratio_500_all\n            if tem_bar < 1.0 and np.random.uniform() > tem_bar:\n                steps -= 1\n                continue\n\n            expert = self.model.summon(dataset_id)\n            opt = self.model.summon_opt(dataset_id)\n            feats = self.multi_handler.trn_handlers[dataset_id].projectors\n            loss, loss_dict = expert.cal_loss((ancs, poss, negs), feats)\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n\n            sample_num = ancs.shape[0]\n            tot_samp_num += sample_num\n            ep_loss += loss.item() * sample_num\n            ep_preloss 
+= loss_dict['preloss'].item() * sample_num\n            ep_regloss += loss_dict['regloss'].item()\n            log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f        ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True)\n\n            counter[dataset_id] += 1\n            if (counter[dataset_id] + 1) % self.multi_handler.trn_handlers[dataset_id].reproj_steps == 0:\n                self.multi_handler.trn_handlers[dataset_id].make_projectors()\n            if (i + 1) % reassign_steps == 0:\n                self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=False)\n        ret = dict()\n        ret['Loss'] = ep_loss / tot_samp_num\n        ret['preLoss'] = ep_preloss / tot_samp_num\n        ret['regLoss'] = ep_regloss / steps\n        t.cuda.empty_cache()\n        return ret\n    \n    def make_trn_masks(self, numpy_usrs, csr_mat):\n        trn_masks = csr_mat[numpy_usrs].tocoo()\n        cand_size = trn_masks.shape[1]\n        trn_masks = t.from_numpy(np.stack([trn_masks.row, trn_masks.col], axis=0)).long()\n        return trn_masks, cand_size\n\n    def test_epoch(self, handler, dataset_id):\n        with t.no_grad():\n            tst_loader = handler.tst_loader\n            self.model.eval()\n            expert = self.model.summon(dataset_id)\n            ep_recall, ep_ndcg = 0, 0\n            ep_tstnum = len(tst_loader.dataset)\n            steps = max(ep_tstnum // args.tst_batch, 1)\n            for i, batch_data in enumerate(tst_loader):\n                # stop after exactly args.tst_steps batches, matching the ep_tstnum override below\n                if args.tst_steps != -1 and i >= args.tst_steps:\n                    break\n\n                usrs = batch_data.long()\n                trn_masks, cand_size = self.make_trn_masks(batch_data.numpy(), tst_loader.dataset.csrmat)\n                feats = handler.projectors\n                all_preds = expert.pred_for_test((usrs, trn_masks), cand_size, feats, 
rerun_embed=(i == 0))\n                _, top_locs = t.topk(all_preds, args.topk)\n                top_locs = top_locs.cpu().numpy()\n                recall, ndcg = self.calc_recall_ndcg(top_locs, tst_loader.dataset.tstLocs, usrs)\n                ep_recall += recall\n                ep_ndcg += ndcg\n                log('Steps %d/%d: recall = %.2f, ndcg = %.2f          ' % (i, steps, recall, ndcg), save=False, oneline=True)\n        ret = dict()\n        if args.tst_steps != -1:\n            ep_tstnum = args.tst_steps * args.tst_batch\n        ret['Recall'] = ep_recall / ep_tstnum\n        ret['NDCG'] = ep_ndcg / ep_tstnum\n        ret['tstNum'] = ep_tstnum\n        t.cuda.empty_cache()\n        return ret\n    \n    def calc_recall_ndcg(self, topLocs, tstLocs, batIds):\n        assert topLocs.shape[0] == len(batIds)\n        allRecall = allNdcg = 0\n        for i in range(len(batIds)):\n            temTopLocs = list(topLocs[i])\n            temTstLocs = tstLocs[batIds[i]]\n            tstNum = len(temTstLocs)\n            maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))])\n            recall = dcg = 0\n            for val in temTstLocs:\n                if val in temTopLocs:\n                    recall += 1\n                    dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2))\n            recall = recall / tstNum\n            ndcg = dcg / maxDcg\n            allRecall += recall\n            allNdcg += ndcg\n        return allRecall, allNdcg\n    \n    def save_history(self):\n        if args.epoch == 0:\n            return\n        with open('./History/' + args.save_path + '.his', 'wb') as fs:\n            pickle.dump(self.metrics, fs)\n\n        content = {\n            'model': self.model,\n        }\n        t.save(content, './Models/' + args.save_path + '.mod')\n        log('Model Saved: %s' % args.save_path)\n\n    def load_model(self):\n        ckp = t.load('./Models/' + args.load_model + 
'.mod')\n        self.model = ckp['model']\n        self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)\n\n        with open('./History/' + args.load_model + '.his', 'rb') as fs:\n            self.metrics = pickle.load(fs)\n        log('Model Loaded')\n\nif __name__ == '__main__':\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\n    if len(args.gpu.split(',')) == 2:\n        args.devices = ['cuda:0', 'cuda:1']\n    elif len(args.gpu.split(',')) > 2:\n        raise Exception('At most 2 GPU devices are supported')\n    else:\n        args.devices = ['cuda:0', 'cuda:0']\n    logger.saveDefault = True\n    setproctitle.setproctitle('AnyGraph')\n\n    log('Start')\n\n    \n    datasets = dict()\n    datasets['all'] = [\n        'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech', 'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab', 'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1'\n    ]\n    datasets['ecommerce'] = [\n        'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech'\n    ]\n    datasets['academic'] = [\n        'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab'\n    ]\n    datasets['others'] = [\n        'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1'\n    ]\n    datasets['link1'] = [\n        'products_tech', 'yelp2018', 'yelp_textfeat', 'products_home', 'steam_textfeat', 'amazon_textfeat', 'amazon-book', 'citation-2019', 'citation-classic', 'pubmed', 'citeseer', 
'ppa', 'p2p-Gnutella06', 'soc-Epinions1', 'email-Enron',\n    ]\n    datasets['link2'] = [\n        'Photo', 'Goodreads', 'Fitness', 'ml1m', 'ml10m', 'gowalla', 'arxiv', 'arxiv-ta', 'cora', 'CS', 'collab', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'ddi', 'web-Stanford', 'roadNet-PA',\n    ]\n\n    if args.dataset_setting in datasets.keys():\n        trn_datasets = tst_datasets = datasets[args.dataset_setting]\n    elif args.dataset_setting in datasets['all']:\n        trn_datasets = tst_datasets = [args.dataset_setting]\n    elif '+' in args.dataset_setting:\n        idx = args.dataset_setting.index('+')\n        trn_datasets = datasets[args.dataset_setting[:idx]]\n        tst_datasets = datasets[args.dataset_setting[idx+1:]]\n    elif '_in_' in args.dataset_setting:\n        # intersection of the two named dataset groups\n        idx = args.dataset_setting.index('_in_')\n        tst_datasets_1 = datasets[args.dataset_setting[:idx]]\n        tst_datasets_2 = datasets[args.dataset_setting[idx+len('_in_'):]]\n        tst_datasets = []\n        for data in tst_datasets_1:\n            if data in tst_datasets_2:\n                tst_datasets.append(data)\n        trn_datasets = tst_datasets\n    else:\n        raise Exception(f'Unrecognized dataset setting: {args.dataset_setting}')\n\n    if '+' not in args.dataset_setting:\n        # No zero-shot prediction test\n        handler = MultiDataHandler(trn_datasets, [tst_datasets])\n    else:\n        handler = MultiDataHandler(trn_datasets, [trn_datasets, tst_datasets])\n    log('Load Data')\n\n    exp = Exp(handler)\n    exp.run()\n    print(args.load_model, args.dataset_setting)\n"
  },
  {
    "path": "model.py",
"content": "import torch as t\nfrom torch import nn\nimport torch.nn.functional as F\nfrom params import args\nimport numpy as np\nfrom Utils.TimeLogger import log\nfrom torch.nn import MultiheadAttention\nfrom time import time\n\ninit = nn.init.xavier_uniform_\nuniformInit = nn.init.uniform_\n\nclass FeedForwardLayer(nn.Module):\n    def __init__(self, in_feat, out_feat, bias=True, act=None):\n        super(FeedForwardLayer, self).__init__()\n        self.linear = nn.Linear(in_feat, out_feat, bias=bias)#, dtype=t.bfloat16)\n        if act == 'identity' or act is None:\n            self.act = None\n        elif act == 'leaky':\n            self.act = nn.LeakyReLU(negative_slope=args.leaky)\n        elif act == 'relu':\n            self.act = nn.ReLU()\n        elif act == 'relu6':\n            self.act = nn.ReLU6()\n        else:\n            raise Exception(f'Unsupported activation: {act}')\n    \n    def forward(self, embeds):\n        if self.act is None:\n            return self.linear(embeds)\n        return self.act(self.linear(embeds))\n\nclass TopoEncoder(nn.Module):\n    def __init__(self):\n        super(TopoEncoder, self).__init__()\n\n        self.layer_norm = nn.LayerNorm(args.latdim, elementwise_affine=False)\n\n    def forward(self, adj, embeds, normed=False):\n        with t.no_grad():\n            if not normed:\n                embeds = self.layer_norm(embeds)\n            # embeds_list = []\n            final_embeds = 0\n            if args.gnn_layer == 0:\n                final_embeds = embeds\n                # embeds_list.append(embeds)\n            for _ in range(args.gnn_layer):\n                embeds = t.spmm(adj, embeds)\n                final_embeds += embeds\n                # embeds_list.append(embeds)\n            embeds = final_embeds#sum(embeds_list)\n        return embeds\n\nclass MLP(nn.Module):\n    def __init__(self):\n        super(MLP, self).__init__()\n        self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, 
bias=True, act=args.act) for _ in range(args.fc_layer)])\n        self.layer_norms = nn.Sequential(*[nn.LayerNorm(args.latdim, elementwise_affine=True) for _ in range(args.fc_layer)])\n        self.dropout = nn.Dropout(p=args.drop_rate)\n    \n    def forward(self, embeds):\n        for i in range(args.fc_layer):\n            embeds = self.layer_norms[i](self.dropout(self.dense_layers[i](embeds)) + embeds)\n        return embeds\n\nclass GTLayer(nn.Module):\n    def __init__(self):\n        super(GTLayer, self).__init__()\n        self.multi_head_attention = MultiheadAttention(args.latdim, args.head, dropout=0.1, bias=False)#, dtype=t.bfloat16)\n        self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(2)])# bias=False\n        self.layer_norm1 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)\n        self.layer_norm2 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)\n        self.fc_dropout = nn.Dropout(p=args.drop_rate)\n    \n    def _pick_anchors(self, embeds):\n        perm = t.randperm(embeds.shape[0])\n        anchors = perm[:args.anchor]\n        return embeds[anchors]\n    \n    def forward(self, embeds):\n        anchor_embeds = self._pick_anchors(embeds)\n        _anchor_embeds, _ = self.multi_head_attention(anchor_embeds, embeds, embeds)\n        anchor_embeds = _anchor_embeds + anchor_embeds\n        _embeds, _ = self.multi_head_attention(embeds, anchor_embeds, anchor_embeds, need_weights=False)\n        embeds = self.layer_norm1(_embeds + embeds)\n        _embeds = self.fc_dropout(self.dense_layers(embeds))\n        embeds = (self.layer_norm2(_embeds + embeds))\n        return embeds\n\nclass GraphTransformer(nn.Module):\n    def __init__(self):\n        super(GraphTransformer, self).__init__()\n        self.gt_layers = nn.Sequential(*[GTLayer() for i in range(args.gt_layer)])\n\n    def forward(self, embeds):\n        for i, layer in 
enumerate(self.gt_layers):\n            embeds = layer(embeds) / args.scale_layer\n        return embeds\n\nclass Feat_Projector(nn.Module):\n    def __init__(self, feats):\n        super(Feat_Projector, self).__init__()\n\n        if args.proj_method == 'uniform':\n            self.proj_embeds = self.uniform_proj(feats)\n        elif args.proj_method == 'svd' or args.proj_method == 'both':\n            self.proj_embeds = self.svd_proj(feats)\n        elif args.proj_method == 'random':\n            self.proj_embeds = self.random_proj(feats)\n        elif args.proj_method == 'original':\n            self.proj_embeds = feats\n        self.proj_embeds = t.flip(self.proj_embeds, dims=[-1])\n        self.proj_embeds = self.proj_embeds.detach()\n    \n    def svd_proj(self, feats):\n        if args.latdim > feats.shape[0] or args.latdim > feats.shape[1]:\n            dim = min(feats.shape[0], feats.shape[1])\n            decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=dim, niter=args.niter)\n            decom_feats = t.concat([decom_feats, t.zeros([decom_feats.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)\n            s = t.concat([s, t.zeros(args.latdim - dim).to(args.devices[0])])\n        else:\n            decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=args.latdim, niter=args.niter)\n        decom_feats = decom_feats @ t.diag(t.sqrt(s))\n        return decom_feats.cpu()\n    \n    def uniform_proj(self, feats):\n        projection = init(t.empty(args.featdim, args.latdim))\n        return feats @ projection\n    \n    def random_proj(self, feats):\n        projection = init(t.empty(feats.shape[0], args.latdim))\n        return projection\n    \n    def forward(self):\n        return self.proj_embeds\n\nclass Adj_Projector(nn.Module):\n    def __init__(self, adj):\n        super(Adj_Projector, self).__init__()\n\n        if args.proj_method == 'adj_svd' or args.proj_method == 'both':\n            self.proj_embeds = self.svd_proj(adj)\n        
self.proj_embeds = self.proj_embeds.detach()\n    \n    def svd_proj(self, adj):\n        q = args.latdim\n        if args.latdim > adj.shape[0] or args.latdim > adj.shape[1]:\n            dim = min(adj.shape[0], adj.shape[1])\n            svd_u, s, svd_v = t.svd_lowrank(adj, q=dim, niter=args.niter)\n            svd_u = t.concat([svd_u, t.zeros([svd_u.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)\n            svd_v = t.concat([svd_v, t.zeros([svd_v.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)\n            s = t.concat([s, t.zeros(args.latdim-dim).to(args.devices[0])])\n        else:\n            svd_u, s, svd_v = t.svd_lowrank(adj, q=q, niter=args.niter)\n        svd_u = svd_u @ t.diag(t.sqrt(s))\n        svd_v = svd_v @ t.diag(t.sqrt(s))\n        if adj.shape[0] != adj.shape[1]:\n            projection = t.concat([svd_u, svd_v], dim=0)\n        else:\n            projection = svd_u + svd_v\n        return projection.cpu()\n    \n    def forward(self):\n        return self.proj_embeds\n\nclass Expert(nn.Module):\n    def __init__(self):\n        super(Expert, self).__init__()\n        \n        self.topo_encoder = TopoEncoder().to(args.devices[0])\n        if args.nn == 'mlp':\n            self.trainable_nn = MLP().to(args.devices[1])\n        else:\n            self.trainable_nn = GraphTransformer().to(args.devices[1])\n        self.trn_count = 1\n    \n    def forward(self, projectors, pck_nodes=None):\n        embeds = projectors.to(args.devices[1])\n        if pck_nodes is not None:\n            embeds = embeds[pck_nodes]\n        embeds = self.trainable_nn(embeds)\n        return embeds\n\n    def pred_norm(self, pos_preds, neg_preds):\n        pos_preds_num = pos_preds.shape[0]\n        neg_preds_shape = neg_preds.shape\n        preds = t.concat([pos_preds, neg_preds.view(-1)])\n        preds = preds - preds.max()\n        pos_preds = preds[:pos_preds_num]\n        neg_preds = preds[pos_preds_num:].view(neg_preds_shape)\n        return 
pos_preds, neg_preds\n    \n    def cal_loss(self, batch_data, projectors):\n        ancs, poss, negs = list(map(lambda x: x.to(args.devices[1]), batch_data))\n        self.trn_count += ancs.shape[0]\n        pck_nodes = t.concat([ancs, poss, negs])\n        final_embeds = self.forward(projectors, pck_nodes)\n        # anc_embeds, pos_embeds, neg_embeds = final_embeds[ancs], final_embeds[poss], final_embeds[negs]\n        anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds, [ancs.shape[0]] * 3)\n        if final_embeds.isinf().any() or final_embeds.isnan().any():\n            raise Exception('Final embedding fails')\n        \n        if args.loss == 'ce':\n            pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T)\n            if pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any():\n                raise Exception('Preds fails')\n            pos_loss = pos_preds\n            neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log()\n            pre_loss = -(pos_loss - neg_loss).mean()\n        elif args.loss == 'bpr':\n            pos_preds = (anc_embeds * pos_embeds).sum(-1)\n            neg_preds = (anc_embeds * neg_embeds).sum(-1)\n            pos_loss, neg_loss = pos_preds, neg_preds\n            pre_loss = -((pos_preds - neg_preds).sigmoid() + 1e-10).log().mean() \n\n        if t.isinf(pre_loss).any() or t.isnan(pre_loss).any():\n            raise Exception('NaN or Inf')\n\n        reg_loss = sum(list(map(lambda W: W.norm(2).square() * args.reg, self.parameters())))\n        loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean()}\n        return pre_loss + reg_loss, loss_dict\n    \n    def pred_for_test(self, batch_data, cand_size, projectors, rerun_embed=True):\n        ancs, trn_mask = list(map(lambda x: x.to(args.devices[1]), batch_data))\n        if rerun_embed:\n            
try:\n                final_embeds = self.forward(projectors)\n            except Exception:\n                final_embeds_list = []\n                div = args.batch * 3\n                temlen = projectors.shape[0] // div\n                for i in range(temlen):\n                    st, ed = div * i, div * (i + 1)\n                    tem_projectors = projectors[st: ed, :]\n                    final_embeds_list.append(self.forward(tem_projectors))\n                if temlen * div < projectors.shape[0]:\n                    tem_projectors = projectors[temlen*div:, :]\n                    final_embeds_list.append(self.forward(tem_projectors))\n                final_embeds = t.concat(final_embeds_list, dim=0)\n            self.final_embeds = final_embeds\n        final_embeds = self.final_embeds\n        anc_embeds = final_embeds[ancs]\n        cand_embeds = final_embeds[-cand_size:]\n\n        mask_mat = t.sparse.FloatTensor(trn_mask, t.ones(trn_mask.shape[1]).to(args.devices[1]), t.Size([ancs.shape[0], cand_size]))\n        dense_mat = mask_mat.to_dense()\n        all_preds = anc_embeds @ cand_embeds.T * (1 - dense_mat) - dense_mat * 1e8\n        return all_preds\n\n    def attempt(self, topo_embeds, dataset):\n        final_embeds = self.trainable_nn(topo_embeds)\n        rows, cols, negs = list(map(lambda x: t.from_numpy(x).long().to(args.devices[1]), [dataset.ancs, dataset.poss, dataset.negs]))\n        if rows.shape[0] > args.attempt_cache:\n            random_perm = t.randperm(rows.shape[0], device=args.devices[0])\n            pck_perm = random_perm[:args.attempt_cache]\n            rows = rows[pck_perm]\n            cols = cols[pck_perm]\n            negs = negs[pck_perm]\n        while True:\n            try:\n                row_embeds = final_embeds[rows]\n                col_embeds = final_embeds[cols]\n                neg_embeds = final_embeds[negs]\n                score = ((row_embeds * col_embeds).sum(-1) - (row_embeds * 
neg_embeds).sum(-1)).sigmoid().mean().item()\n                break\n            except Exception:\n                args.attempt_cache = args.attempt_cache // 2\n                random_perm = t.randperm(rows.shape[0], device=args.devices[0])\n                pck_perm = random_perm[:args.attempt_cache]\n                rows = rows[pck_perm]\n                cols = cols[pck_perm]\n                negs = negs[pck_perm]\n        t.cuda.empty_cache()\n        return score\n\nclass AnyGraph(nn.Module):\n    def __init__(self):\n        super(AnyGraph, self).__init__()\n        self.experts = nn.ModuleList([Expert() for _ in range(args.expert_num)]).cuda()\n        self.opts = list(map(lambda expert: t.optim.Adam(expert.parameters(), lr=args.lr, weight_decay=0), self.experts))\n        \n    def assign_experts(self, handlers, reca=True, log_assignment=False):\n        if args.expert_num == 1:\n            self.assignment = [0] * len(handlers)\n            return\n        try:\n            expert_scores = np.array(list(map(lambda expert: expert.trn_count, self.experts)))\n            expert_scores = (1.0 - expert_scores / np.sum(expert_scores)) * args.reca_range + 1.0 - args.reca_range / 2\n        except Exception:\n            expert_scores = np.ones(len(self.experts))\n        with t.no_grad():\n            assignment = [list() for i in range(len(handlers))]\n            for dataset_id, handler in enumerate(handlers):\n                topo_embeds = handler.projectors.to(args.devices[1])\n                for expert_id, expert in enumerate(self.experts):\n                    expert = expert.to(args.devices[1])\n                    score = expert.attempt(topo_embeds, handler.trn_loader.dataset)\n                    if reca:\n                        score *= expert_scores[expert_id]\n                    assignment[dataset_id].append((expert_id, score))\n                assignment[dataset_id].sort(key=lambda x: x[1], reverse=True)\n            if log_assignment:\n            
    print('\\n----------\\nAssignment')\n                for dataset_id, handler in enumerate(handlers):\n                    out = ''\n                    for exp_idx in range(min(4, len(self.experts))):\n                        out += f'({assignment[dataset_id][exp_idx][0]}, {assignment[dataset_id][exp_idx][1]}) '\n                    print(handler.data_name, out)\n                print('----------\\n')\n\n            self.assignment = list(map(lambda x: x[0][0], assignment))\n    \n    def summon(self, dataset_id):\n        return self.experts[self.assignment[dataset_id]]\n    \n    def summon_opt(self, dataset_id):\n        return self.opts[self.assignment[dataset_id]]"
  },
  {
    "path": "node_classification/Utils/TimeLogger.py",
"content": "import datetime\n\nlogmsg = ''\ntimemark = dict()\nsaveDefault = False\ndef log(msg, save=None, oneline=False):\n\tglobal logmsg\n\tglobal saveDefault\n\ttime = datetime.datetime.now()\n\ttem = '%s: %s' % (time, msg)\n\tif save is not None:\n\t\tif save:\n\t\t\tlogmsg += tem + '\\n'\n\telif saveDefault:\n\t\tlogmsg += tem + '\\n'\n\tif oneline:\n\t\tprint(tem, end='\\r')\n\telse:\n\t\tprint(tem)\n\ndef marktime(marker):\n\tglobal timemark\n\ttimemark[marker] = datetime.datetime.now()\n\n\nif __name__ == '__main__':\n\tlog('')"
  },
  {
    "path": "node_classification/data_handler.py",
    "content": "import pickle\nimport numpy as np\nfrom scipy.sparse import csr_matrix, coo_matrix, dok_matrix\nfrom params import args\nimport scipy.sparse as sp\nfrom Utils.TimeLogger import log\nimport torch as t\nimport torch.utils.data as data\nimport torch_geometric.transforms as T\nfrom model import Feat_Projector, Adj_Projector, TopoEncoder\nimport os\n\nclass MultiDataHandler:\n    def __init__(self, trn_datasets, tst_datasets_group):\n        all_datasets = trn_datasets\n        all_tst_datasets = []\n        for tst_datasets in tst_datasets_group:\n            all_datasets = all_datasets + tst_datasets\n            all_tst_datasets += tst_datasets\n        all_datasets = list(set(all_datasets))\n        all_datasets.sort()\n        self.trn_handlers = []\n        self.tst_handlers_group = [list() for i in range(len(tst_datasets_group))]\n        for data_name in all_datasets:\n            trn_flag = data_name in trn_datasets\n            tst_flag = data_name in all_tst_datasets\n            handler = DataHandler(data_name)\n            if trn_flag:\n                self.trn_handlers.append(handler)\n            if tst_flag:\n                for i in range(len(tst_datasets_group)):\n                    if data_name in tst_datasets_group[i]:\n                        self.tst_handlers_group[i].append(handler)\n        self.make_joint_trn_loader()\n    \n    def make_joint_trn_loader(self):\n        loader_datasets = []\n        for trn_handler in self.trn_handlers:\n            tem_dataset = trn_handler.trn_loader.dataset\n            loader_datasets.append(tem_dataset)\n        joint_dataset = JointTrnData(loader_datasets)\n        self.joint_trn_loader = data.DataLoader(joint_dataset, batch_size=1, shuffle=True, num_workers=0)\n    \n    def remake_initial_projections(self):\n        for i in range(len(self.trn_handlers)):\n            trn_handler = self.trn_handlers[i]\n            trn_handler.make_projectors()\n\nclass DataHandler:\n    def 
__init__(self, data_name):\n        self.data_name = data_name\n        self.get_data_files()\n        log(f'Loading dataset {data_name}')\n        self.topo_encoder = TopoEncoder()\n        self.load_data()\n    \n    def get_data_files(self):\n        predir = f'/home/user_name/data/zero-shot datasets/node_data/{self.data_name}/'\n        # predir = f'../handle_node_data/{self.data_name}/'\n        if os.path.exists(predir + 'feats.pkl'):\n            self.feat_file = predir + 'feats.pkl'\n        else:\n            self.feat_file = None\n        self.trnfile = predir + 'trn_mat.pkl'\n        self.tstfile = predir + 'tst_mat.pkl'\n        self.fewshotfile = predir + 'fewshot_mat_{shot}.pkl'.format(shot=args.shot)\n        self.valfile = predir + 'val_mat.pkl'\n\n    def load_one_file(self, filename):\n        with open(filename, 'rb') as fs:\n            ret = (pickle.load(fs) != 0).astype(np.float32)\n        if type(ret) != coo_matrix:\n            ret = sp.coo_matrix(ret)\n        return ret\n    \n    def load_feats(self, filename):\n        try:\n            with open(filename, 'rb') as fs:\n                feats = pickle.load(fs)\n        except Exception as e:\n            print(filename + str(e))\n            exit()\n        return feats\n\n    def normalize_adj(self, mat, log=False):\n        degree = np.array(mat.sum(axis=-1))\n        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])\n        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0\n        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)\n        if mat.shape[0] == mat.shape[1]:\n            return mat.dot(d_inv_sqrt_mat).transpose().dot(d_inv_sqrt_mat).tocoo()\n        else:\n            tem = d_inv_sqrt_mat.dot(mat)\n            col_degree = np.array(mat.sum(axis=0))\n            d_inv_sqrt = np.reshape(np.power(col_degree, -0.5), [-1])\n            d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0\n            d_inv_sqrt_mat = sp.diags(d_inv_sqrt)\n            return tem.dot(d_inv_sqrt_mat).tocoo()\n    \n    def 
unique_numpy(self, row, col):\n        hash_vals = row * args.node_num + col\n        hash_vals = np.unique(hash_vals).astype(np.int64)\n        col = hash_vals % args.node_num\n        row = (hash_vals - col).astype(np.int64) // args.node_num\n        return row, col\n\n    def make_torch_adj(self, mat, unidirectional_for_asym=False):\n        if mat.shape[0] == mat.shape[1]:\n            _row = mat.row\n            _col = mat.col\n            row = np.concatenate([_row, _col]).astype(np.int64)\n            col = np.concatenate([_col, _row]).astype(np.int64)\n            row, col = self.unique_numpy(row, col)\n            data = np.ones_like(row)\n            mat = coo_matrix((data, (row, col)), mat.shape)\n            if args.selfloop == 1:\n                mat = (mat + sp.eye(mat.shape[0])) * 1.0\n            normed_asym_mat = self.normalize_adj(mat)\n            row = t.from_numpy(normed_asym_mat.row).long()\n            col = t.from_numpy(normed_asym_mat.col).long()\n            idxs = t.stack([row, col], dim=0)\n            vals = t.from_numpy(normed_asym_mat.data).float()\n            shape = t.Size(normed_asym_mat.shape)\n            asym_adj = t.sparse.FloatTensor(idxs, vals, shape)\n            return asym_adj\n        elif unidirectional_for_asym:\n            mat = (mat != 0) * 1.0\n            mat = self.normalize_adj(mat, log=True)\n            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))\n            vals = t.from_numpy(mat.data.astype(np.float32))\n            shape = t.Size(mat.shape)\n            return t.sparse.FloatTensor(idxs, vals, shape)\n        else:\n            # make ui adj\n            a = sp.csr_matrix((args.user_num, args.user_num))\n            b = sp.csr_matrix((args.item_num, args.item_num))\n            mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])\n            mat = (mat != 0) * 1.0\n            if args.selfloop == 1:\n                mat = (mat + sp.eye(mat.shape[0])) * 1.0\n      
      mat = self.normalize_adj(mat)\n\n            # make cuda tensor\n            idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))\n            vals = t.from_numpy(mat.data.astype(np.float32))\n            shape = t.Size(mat.shape)\n            return t.sparse.FloatTensor(idxs, vals, shape)\n\n    def load_data(self):\n        tst_mat = self.load_one_file(self.tstfile)\n        trn_mat = self.load_one_file(self.trnfile)\n        # fewshot_mat = self.load_one_file(self.fewshotfile)\n        if self.feat_file is not None:\n            self.feats = t.from_numpy(self.load_feats(self.feat_file)).float()\n            self.feats = self.feats\n            args.featdim = self.feats.shape[1]\n        else:\n            self.feats = None\n            args.featdim = args.latdim\n\n        if trn_mat.shape[0] != trn_mat.shape[1]:\n            args.user_num, args.item_num = trn_mat.shape\n            args.node_num = args.user_num + args.item_num\n            print('Dataset: {data_name}, User num: {user_num}, Item num: {item_num}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, user_num=args.user_num, item_num=args.item_num, node_num=args.node_num, edge_num=trn_mat.nnz))\n        else:\n            args.node_num = trn_mat.shape[0]\n            print('Dataset: {data_name}, Node num: {node_num}, Edge num: {edge_num}'.format(data_name=self.data_name, node_num=args.node_num, edge_num=trn_mat.nnz+tst_mat.nnz))\n        if args.tst_mode == 'tst':\n            tst_data = NodeTstData(tst_mat)\n            self.tst_loader = data.DataLoader(tst_data, batch_size=args.tst_batch, shuffle=False, num_workers=0)\n            # tst_loss_data = TrnData(tst_mat)\n            # self.tst_loss_loader = data.DataLoader(tst_loss_data, batch_size=args.batch, shuffle=False, num_workers=0)\n            self.tst_input_adj = self.make_torch_adj(trn_mat)\n        elif args.tst_mode == 'val':\n            raise Exception('Validation not made for node 
classification')\n        else:\n            raise Exception('Specify proper test mode')\n\n\n        self.trn_mat = trn_mat\n        trn_data = TrnData(self.trn_mat)\n        self.trn_loader = data.DataLoader(trn_data, batch_size=args.batch, shuffle=True, num_workers=0)\n        if args.tst_mode == 'tst':\n            self.trn_input_adj = self.tst_input_adj\n        else:\n            self.trn_input_adj = self.make_torch_adj(trn_mat)\n\n        if self.trn_mat.shape[0] == self.trn_mat.shape[1]:\n            self.asym_adj = self.trn_input_adj\n        else:\n            self.asym_adj = self.make_torch_adj(self.trn_mat, unidirectional_for_asym=True)\n        self.make_projectors()\n        self.reproj_steps = max(len(self.trn_loader.dataset) // (10 * args.batch), args.proj_trn_steps)\n        self.ratio_500_all = 250 / len(self.trn_loader)\n    \n    def make_projectors(self):\n        with t.no_grad():\n            projectors = []\n            if args.proj_method == 'adj_svd' or args.proj_method == 'both':\n                tem = self.asym_adj.to(args.devices[0])\n                projectors = [Adj_Projector(tem)]\n            if self.feats is not None and args.proj_method != 'adj_svd':\n                tem = self.feats.to(args.devices[0])\n                projectors.append(Feat_Projector(tem))\n            assert args.tst_mode == 'tst' and args.trn_mode == 'train-all' or args.tst_mode == 'val' and args.trn_mode == 'fewshot'\n            feats = projectors[0]()\n            if len(projectors) == 2:\n                feats2 = projectors[1]()\n                feats = feats + feats2\n\n            try:\n                self.projectors = self.topo_encoder(self.trn_input_adj.to(args.devices[0]), feats.to(args.devices[0])).detach().cpu()\n            except Exception:\n                print(f'{self.data_name} memory overflow')\n                mean, std = feats.mean(dim=-1, keepdim=True), feats.std(dim=-1, keepdim=True)\n                tem_adj = 
self.trn_input_adj.to(args.devices[0])\n                mem_cache = 256\n                projectors_list = []\n                for st in range(0, feats.shape[1], mem_cache):\n                    ed = min(st + mem_cache, feats.shape[1])  # cover the remainder columns as well\n                    tem_feats = (feats[:, st:ed] - mean) / (std + 1e-8)\n                    tem_feats = self.topo_encoder(tem_adj, tem_feats.to(args.devices[0]), normed=True).detach().cpu()\n                    projectors_list.append(tem_feats)\n                self.projectors = t.concat(projectors_list, dim=-1)\n            t.cuda.empty_cache()\n\nclass TrnData(data.Dataset):\n    def __init__(self, coomat):\n        self.ancs, self.poss = coomat.row, coomat.col\n        # self.dokmat = set(list(map(lambda idx: (self.ancs[idx], self.poss[idx]), range(len(self.ancs)))))\n        # self.dokmat = coomat.todok()\n        self.negs = np.zeros(len(self.ancs)).astype(np.int32)\n        self.cand_num = coomat.shape[1]\n        self.neg_shift = 0 if coomat.shape[0] == coomat.shape[1] else coomat.shape[0]\n        self.poss = coomat.col + self.neg_shift\n        self.neg_sampling()\n    \n    def neg_sampling(self):\n        # draw negatives from the candidate (item) index range only\n        self.negs = np.random.randint(self.cand_num, size=self.poss.shape[0]) + self.neg_shift\n        # self.negs = np.zeros_like(self.ancs)\n        # for i in range(len(self.ancs)):\n        #     u = self.ancs[i]\n        #     while True:\n        #         i_neg = np.random.randint(self.cand_num)\n        #         if (u, i_neg) not in self.dokmat:\n        #             break\n        #     self.negs[i] = i_neg\n        # self.negs += self.neg_shift\n\n    def __len__(self):\n        return len(self.ancs)\n    \n    def __getitem__(self, idx):\n        return self.ancs[idx], self.poss[idx], self.negs[idx]\n\nclass NodeTstData(data.Dataset):\n    def __init__(self, tst_mat):\n        self.class_num = tst_mat.shape[1]\n        self.nodes, self.labels = tst_mat.row, tst_mat.col\n\n    def __len__(self):\n        return 
len(self.nodes)\n    \n    def __getitem__(self, idx):\n        return self.nodes[idx], self.labels[idx]\n\nclass JointTrnData(data.Dataset):\n    def __init__(self, dataset_list):\n        self.batch_dataset_ids = []\n        self.batch_st_ed_list = []\n        self.dataset_list = dataset_list\n        for dataset_id, dataset in enumerate(dataset_list):\n            samp_num = len(dataset) // args.batch + (1 if len(dataset) % args.batch != 0 else 0)\n            for j in range(samp_num):\n                self.batch_dataset_ids.append(dataset_id)\n                st = j * args.batch\n                ed = min((j + 1) * args.batch, len(dataset))\n                self.batch_st_ed_list.append((st, ed))\n    \n    def neg_sampling(self):\n        for dataset in self.dataset_list:\n            dataset.neg_sampling()\n\n    def __len__(self):\n        return len(self.batch_dataset_ids)\n    \n    def __getitem__(self, idx):\n        st, ed = self.batch_st_ed_list[idx]\n        dataset_id = self.batch_dataset_ids[idx]\n        return *self.dataset_list[dataset_id][st: ed], dataset_id\n"
  },
  {
    "path": "node_classification/main.py",
    "content": "import torch as t\nfrom torch import nn\nimport Utils.TimeLogger as logger\nfrom Utils.TimeLogger import log\nfrom params import args\nfrom model import Expert, Feat_Projector, Adj_Projector, AnyGraph\nfrom data_handler import MultiDataHandler, DataHandler\nimport numpy as np\nimport pickle\nimport os\nimport setproctitle\nimport time\nfrom sklearn.metrics import f1_score\n\nclass Exp:\n    def __init__(self, multi_handler):\n        self.multi_handler = multi_handler\n        print(list(map(lambda x: x.data_name, multi_handler.trn_handlers)))\n        for group_id, tst_handlers in enumerate(multi_handler.tst_handlers_group):\n            print(f'Test group {group_id}', list(map(lambda x: x.data_name, tst_handlers)))\n        self.metrics = dict()\n        trn_mets = ['Loss', 'preLoss']\n        tst_mets = ['Recall', 'NDCG', 'Loss', 'preLoss']\n        mets = trn_mets + tst_mets\n        for met in mets:\n            if met in trn_mets:\n                self.metrics['Train' + met] = list()\n            if met in tst_mets:\n                for i in range(len(self.multi_handler.tst_handlers_group)):\n                    self.metrics['Test' + str(i) + met] = list()\n        \n    def make_print(self, name, ep, reses, save, data_name=None):\n        if data_name is None:\n            ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)\n        else:\n            ret = 'Epoch %d/%d, %s %s: ' % (ep, args.epoch, data_name, name)\n        for metric in reses:\n            val = reses[metric]\n            ret += '%s = %.4f, ' % (metric, val)\n            tem = name + metric if data_name is None else name + data_name + metric\n            if save and tem in self.metrics:\n                self.metrics[tem].append(val)\n        ret = ret[:-2] + '      '\n        return ret\n    \n    def run(self):\n        self.prepare_model()\n        log('Model Prepared')\n        stloc = 0\n        if args.load_model != None:\n            self.load_model()\n            stloc 
= len(self.metrics['TrainLoss']) * args.tst_epoch - (args.tst_epoch - 1)\n        best_ndcg, best_ep = 0, -1\n        early_stop = False\n        for ep in range(stloc, args.epoch):\n            tst_flag = (ep % args.tst_epoch == 0)\n            start_time = time.time()\n            self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=True)\n            reses = self.train_epoch()\n            log(self.make_print('Train', ep, reses, tst_flag))\n            self.multi_handler.remake_initial_projections()\n            end_time = time.time()\n            print(f'NOTICE: epoch time {end_time - start_time:.2f}s')\n            if tst_flag:\n                for handler_group_id in range(len(self.multi_handler.tst_handlers_group)):\n                    tst_handlers = self.multi_handler.tst_handlers_group[handler_group_id]\n                    self.model.assign_experts(tst_handlers, reca=False, log_assignment=True)\n                    recall, ndcg, tstnum = 0, 0, 0\n                    for i, handler in enumerate(tst_handlers):\n                        reses = self.test_epoch(handler, i)\n                        log(self.make_print(handler.data_name, ep, reses, False))\n                        # test_epoch reports Acc/F1 for node classification\n                        recall += reses['Acc'] * reses['tstNum']\n                        ndcg += reses['F1'] * reses['tstNum']\n                        tstnum += reses['tstNum']\n                    # aggregated Acc/F1 kept under the legacy 'Recall'/'NDCG' metric names\n                    reses = {'Recall': recall / tstnum, 'NDCG': ndcg / tstnum}\n                    log(self.make_print('Test'+str(handler_group_id), ep, reses, tst_flag))\n\n                    if reses['NDCG'] > best_ndcg:\n                        best_ndcg = reses['NDCG']\n                        best_ep = ep\n                self.save_history()\n            print()\n\n        for test_group_id in range(len(self.multi_handler.tst_handlers_group)):\n            repeat_times = 10\n            overall_acc, overall_f1 = np.zeros(repeat_times), np.zeros(repeat_times)\n            overall_tstnum = 0\n            tst_handlers = self.multi_handler.tst_handlers_group[test_group_id]\n            if args.assignment == 'one-graph-one-expert':\n                self.model.assign_experts(tst_handlers, reca=False, log_assignment=True)\n            for i, handler in enumerate(tst_handlers):\n                mets = dict()\n                for _ in range(repeat_times):\n                    handler.make_projectors()\n                    if args.assignment != 'one-graph-one-expert':\n                        self.model.assign_experts([handler], reca=False, log_assignment=False)\n                    reses = self.test_epoch(handler, i if args.assignment == 'one-graph-one-expert' else 0)\n                    log(self.make_print('Test', args.epoch, reses, False))\n                    for met in reses:\n                        if met not in mets:\n                            mets[met] = []\n                        mets[met].append(reses[met])\n                tstnum = reses['tstNum']\n                tot_reses = dict()\n                for met in reses:\n                    tem_arr = np.array(mets[met])\n                    tot_reses[met + '_std'] = tem_arr.std()\n                    tot_reses[met + '_mean'] = tem_arr.mean()\n\n                overall_acc += np.array(mets['Acc']) * tstnum\n                overall_f1 += np.array(mets['F1']) * tstnum\n                overall_tstnum += tstnum\n                log(self.make_print('Test', args.epoch, tot_reses, False, handler.data_name))\n            overall_acc /= overall_tstnum\n            overall_f1 /= overall_tstnum\n            overall_res = dict()\n            overall_res['Acc_mean'] = overall_acc.mean()\n            overall_res['Acc_std'] = overall_acc.std()\n            overall_res['F1_mean'] = overall_f1.mean()\n            overall_res['F1_std'] = overall_f1.std()\n            
log(self.make_print('Overall Test', args.epoch, overall_res, False))\n        self.save_history()\n\n    def print_model_size(self):\n        total_params = 0\n        trainable_params = 0\n        non_trainable_params = 0\n        for param in self.model.parameters():\n            tem = np.prod(param.size())\n            total_params += tem\n            if param.requires_grad:\n                trainable_params += tem\n            else:\n                non_trainable_params += tem\n        print(f'Total params: {total_params/1e6}')\n        print(f'Trainable params: {trainable_params/1e6}')\n        print(f'Non-trainable params: {non_trainable_params/1e6}')\n\n    def prepare_model(self):\n        self.model = AnyGraph()\n        t.cuda.empty_cache()\n        self.print_model_size()\n\n    def train_epoch(self):\n        self.model.train()\n        trn_loader = self.multi_handler.joint_trn_loader\n        trn_loader.dataset.neg_sampling()\n        ep_loss, ep_preloss, ep_regloss = 0, 0, 0\n        steps = len(trn_loader)\n        tot_samp_num = 0\n        counter = [0] * len(self.multi_handler.trn_handlers)\n        reassign_steps = sum(list(map(lambda x: x.reproj_steps, self.multi_handler.trn_handlers)))\n        for i, batch_data in enumerate(trn_loader):\n            if args.epoch_max_step > 0 and i >= args.epoch_max_step:\n                break\n            ancs, poss, negs, dataset_id = batch_data\n            ancs = ancs[0].long()\n            poss = poss[0].long()\n            negs = negs[0].long()\n            dataset_id = dataset_id[0].long()\n            tem_bar = self.multi_handler.trn_handlers[dataset_id].ratio_500_all\n            if tem_bar < 1.0 and np.random.uniform() > tem_bar:\n                steps -= 1\n                continue\n\n            expert = self.model.summon(dataset_id)#.cuda()\n            opt = self.model.summon_opt(dataset_id)\n            # adj = self.multi_handler.trn_handlers[dataset_id].trn_input_adj\n            feats = 
self.multi_handler.trn_handlers[dataset_id].projectors\n            loss, loss_dict = expert.cal_loss((ancs, poss, negs), feats)\n            opt.zero_grad()\n            loss.backward()\n            # nn.utils.clip_grad_norm_(expert.parameters(), max_norm=20, norm_type=2)\n            opt.step()\n\n            sample_num = ancs.shape[0]\n            tot_samp_num += sample_num\n            ep_loss += loss.item() * sample_num\n            ep_preloss += loss_dict['preloss'].item() * sample_num\n            ep_regloss += loss_dict['regloss'].item()\n            log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f        ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True)\n\n            counter[dataset_id] += 1\n            if (counter[dataset_id] + 1) % self.multi_handler.trn_handlers[dataset_id].reproj_steps == 0:\n            # if args.proj_trn_steps > 0 and counter[dataset_id] >= args.proj_trn_steps:\n                self.multi_handler.trn_handlers[dataset_id].make_projectors()\n            if (i + 1) % reassign_steps == 0:\n                self.model.assign_experts(self.multi_handler.trn_handlers, reca=True, log_assignment=False)\n        ret = dict()\n        ret['Loss'] = ep_loss / tot_samp_num\n        ret['preLoss'] = ep_preloss / tot_samp_num\n        ret['regLoss'] = ep_regloss / steps\n        t.cuda.empty_cache()\n        return ret\n    \n    def make_trn_masks(self, numpy_usrs, csr_mat):\n        trn_masks = csr_mat[numpy_usrs].tocoo()\n        cand_size = trn_masks.shape[1]\n        trn_masks = t.from_numpy(np.stack([trn_masks.row, trn_masks.col], axis=0)).long()\n        return trn_masks, cand_size\n\n    def test_loss_epoch(self, handler, dataset_id):\n        with t.no_grad():\n            tst_loader = handler.tst_loss_loader\n            self.model.eval()\n            expert = self.model.summon(dataset_id)#.cuda()\n            ep_loss, 
ep_preloss, ep_regloss = 0, 0, 0\n            steps = len(tst_loader)\n            tot_samp_num = 0\n            for i, batch_data in enumerate(tst_loader):\n                ancs, poss, negs = batch_data\n                ancs = ancs.long()\n                poss = poss.long()\n                negs = negs.long()\n                # adj = handler.tst_input_adj\n                feats = handler.projectors\n                loss, loss_dict = expert.cal_loss((ancs, poss, negs), feats)\n                \n                sample_num = ancs.shape[0]\n                tot_samp_num += sample_num\n                ep_loss += loss.item() * sample_num\n                ep_preloss += loss_dict['preloss'].item() * sample_num\n                ep_regloss += loss_dict['regloss'].item()\n                log('Step %d/%d: loss = %.3f, pre = %.3f, reg = %.3f, pos = %.3f, neg = %.3f        ' % (i, steps, loss, loss_dict['preloss'], loss_dict['regloss'], loss_dict['posloss'], loss_dict['negloss']), save=False, oneline=True)\n\n        ret = dict()\n        ret['Loss'] = ep_loss / tot_samp_num\n        ret['preLoss'] = ep_preloss / tot_samp_num\n        ret['regLoss'] = ep_regloss / steps\n        ret['tot_samp_num'] = tot_samp_num\n        t.cuda.empty_cache()\n        return ret\n    \n    def test_epoch(self, handler, dataset_id):\n        with t.no_grad():\n            tst_loader = handler.tst_loader\n            class_num = tst_loader.dataset.class_num\n            self.model.eval()\n            expert = self.model.summon(dataset_id)\n            ep_acc, ep_tot = 0, 0\n            steps = len(tst_loader)\n            for i, batch_data in enumerate(tst_loader):\n                nodes, labels = list(map(lambda x: x.long().cuda(), batch_data))\n                feats = handler.projectors\n                preds = expert.pred_for_node_test(nodes, class_num, feats, rerun_embed=False if i!=0 else True)\n                if i == 0:\n                    all_preds, all_labels = preds, labels\n            
    else:\n                    all_preds = t.concatenate([all_preds, preds])\n                    all_labels = t.concatenate([all_labels, labels])\n                hit = (labels == preds).float().sum().item()\n                ep_acc += hit\n                ep_tot += labels.shape[0]\n                log('Steps %d/%d: hit = %d, tot = %d          ' % (i, steps, ep_acc, ep_tot), save=False, oneline=True)\n        ret = dict()\n        ret['Acc'] = ep_acc / ep_tot\n        ret['F1'] = f1_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='macro')\n        ret['tstNum'] = ep_tot\n        t.cuda.empty_cache()\n        return ret\n\n    \n    def calc_recall_ndcg(self, topLocs, tstLocs, batIds):\n        assert topLocs.shape[0] == len(batIds)\n        allRecall = allNdcg = 0\n        for i in range(len(batIds)):\n            temTopLocs = list(topLocs[i])\n            temTstLocs = tstLocs[batIds[i]]\n            tstNum = len(temTstLocs)\n            maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))])\n            recall = dcg = 0\n            for val in temTstLocs:\n                if val in temTopLocs:\n                    recall += 1\n                    dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2))\n            recall = recall / tstNum\n            ndcg = dcg / maxDcg\n            allRecall += recall\n            allNdcg += ndcg\n        return allRecall, allNdcg\n    \n    def save_history(self):\n        if args.epoch == 0:\n            return\n        with open('../History/' + args.save_path + '.his', 'wb') as fs:\n            pickle.dump(self.metrics, fs)\n\n        content = {\n            'model': self.model,\n        }\n        t.save(content, '../Models/' + args.save_path + '.mod')\n        log('Model Saved: %s' % args.save_path)\n\n    def load_model(self):\n        ckp = t.load('../Models/' + args.load_model + '.mod')\n        self.model = ckp['model']\n        # 
self.model.set_initial_projection(self.handler.torch_adj)\n        self.opt = t.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)\n\n        with open('../History/' + args.load_model + '.his', 'rb') as fs:\n            self.metrics = pickle.load(fs)\n        log('Model Loaded')\n\nif __name__ == '__main__':\n    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\n    if len(args.gpu.split(',')) == 2:\n        args.devices = ['cuda:0', 'cuda:1']\n    elif len(args.gpu.split(',')) > 2:\n        raise Exception('Devices should be less than 2')\n    else:\n        args.devices = ['cuda:0', 'cuda:0']\n    logger.saveDefault = True\n    setproctitle.setproctitle('akaxia_AutoGraph')\n\n    log('Start')\n\n    \n    datasets = dict()\n    datasets['all'] = [\n        'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech', 'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab', 'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1',\n    ]\n    datasets['ecommerce'] = [\n        'amazon-book', 'yelp2018', 'gowalla', 'yelp_textfeat', 'amazon_textfeat', 'steam_textfeat', 'Goodreads', 'Fitness', 'Photo', 'ml1m', 'ml10m', 'products_home', 'products_tech'\n    ]\n    datasets['academic'] = [\n        'cora', 'pubmed', 'citeseer', 'CS', 'arxiv', 'arxiv-ta', 'citation-2019', 'citation-classic', 'collab'\n    ]\n    datasets['others'] = [\n        'ddi', 'ppa', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'email-Enron', 'web-Stanford', 'roadNet-PA', 'p2p-Gnutella06', 'soc-Epinions1'\n    ]\n    datasets['div1'] = [\n        'products_tech', 'yelp2018', 'yelp_textfeat', 'products_home', 'steam_textfeat', 'amazon_textfeat', 'amazon-book', 'citation-2019', 'citation-classic', 
'pubmed', 'citeseer', 'ppa', 'p2p-Gnutella06', 'soc-Epinions1', 'email-Enron',\n    ]\n    datasets['div2'] = [\n        'Photo', 'Goodreads', 'Fitness', 'ml1m', 'ml10m', 'gowalla', 'arxiv', 'arxiv-ta', 'cora', 'CS', 'collab', 'proteins_spec0', 'proteins_spec1', 'proteins_spec2', 'proteins_spec3', 'ddi', 'web-Stanford', 'roadNet-PA',\n    ]\n    datasets['node'] = [\n        'cora', 'arxiv', 'pubmed', 'home', 'tech'\n    ]\n\n    if args.dataset_setting in datasets.keys():\n        trn_datasets = tst_datasets = datasets[args.dataset_setting]\n    elif args.dataset_setting in datasets['all']:\n        trn_datasets = tst_datasets = [args.dataset_setting]\n    elif '+' in args.dataset_setting:\n        idx = args.dataset_setting.index('+')\n        trn_datasets = datasets[args.dataset_setting[:idx]]\n        tst_datasets = datasets[args.dataset_setting[idx+1:]]\n    elif '_in_' in args.dataset_setting:\n        idx = args.dataset_setting.index('_in_')\n        tst_datasets_1 = datasets[args.dataset_setting[:idx]]\n        tst_datasets_2 = datasets[args.dataset_setting[idx+len('_in_'):]]\n        tst_datasets = []\n        for data in tst_datasets_1:\n            if data in tst_datasets_2:\n                tst_datasets.append(data)\n        trn_datasets = tst_datasets\n\n    # trn_datasets = tst_datasets = ['products_home']\n    if '+' not in args.dataset_setting:\n        handler = MultiDataHandler(trn_datasets, [tst_datasets])\n    else:\n        handler = MultiDataHandler(trn_datasets, [trn_datasets, tst_datasets])\n    log('Load Data')\n\n    exp = Exp(handler)\n    exp.run()\n    print(args.load_model, args.dataset_setting)\n"
  },
  {
    "path": "node_classification/model.py",
    "content": "import torch as t\nfrom torch import nn\nimport torch.nn.functional as F\nfrom params import args\nimport numpy as np\nfrom Utils.TimeLogger import log\nfrom torch.nn import MultiheadAttention\nfrom time import time\n\ninit = nn.init.xavier_uniform_\n# init = nn.init.normal_\nuniformInit = nn.init.uniform_\n\nclass FeedForwardLayer(nn.Module):\n    def __init__(self, in_feat, out_feat, bias=True, act=None):\n        super(FeedForwardLayer, self).__init__()\n        self.linear = nn.Linear(in_feat, out_feat, bias=bias)#, dtype=t.bfloat16)\n        if act == 'identity' or act is None:\n            self.act = None\n        elif act == 'leaky':\n            self.act = nn.LeakyReLU(negative_slope=args.leaky)\n        elif act == 'relu':\n            self.act = nn.ReLU()\n        elif act == 'relu6':\n            self.act = nn.ReLU6()\n        else:\n            raise Exception('Error')\n    \n    def forward(self, embeds):\n        if self.act is None:\n            return self.linear(embeds)\n        return self.act(self.linear(embeds))\n\nclass TopoEncoder(nn.Module):\n    def __init__(self):\n        super(TopoEncoder, self).__init__()\n\n        self.layer_norm = nn.LayerNorm(args.latdim, elementwise_affine=False)\n\n    def forward(self, adj, embeds, normed=False):\n        with t.no_grad():\n            if not normed:\n                embeds = self.layer_norm(embeds)\n            # embeds_list = []\n            final_embeds = 0\n            if args.gnn_layer == 0:\n                final_embeds = embeds\n                # embeds_list.append(embeds)\n            for _ in range(args.gnn_layer):\n                embeds = t.spmm(adj, embeds)\n                final_embeds += embeds\n                # embeds_list.append(embeds)\n            embeds = final_embeds#sum(embeds_list)\n        return embeds\n\nclass MLP(nn.Module):\n    def __init__(self):\n        super(MLP, self).__init__()\n        self.dense_layers = 
nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(args.fc_layer)])\n        self.layer_norms = nn.Sequential(*[nn.LayerNorm(args.latdim, elementwise_affine=True) for _ in range(args.fc_layer)])\n        self.dropout = nn.Dropout(p=args.drop_rate)\n    \n    def forward(self, embeds):\n        for i in range(args.fc_layer):\n            embeds = self.layer_norms[i](self.dropout(self.dense_layers[i](embeds)) + embeds)\n        return embeds\n\nclass GTLayer(nn.Module):\n    def __init__(self):\n        super(GTLayer, self).__init__()\n        self.multi_head_attention = MultiheadAttention(args.latdim, args.head, dropout=0.1, bias=False)#, dtype=t.bfloat16)\n        self.dense_layers = nn.Sequential(*[FeedForwardLayer(args.latdim, args.latdim, bias=True, act=args.act) for _ in range(2)])# bias=False\n        self.layer_norm1 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)\n        self.layer_norm2 = nn.LayerNorm(args.latdim, elementwise_affine=True)#, dtype=t.bfloat16)\n        self.fc_dropout = nn.Dropout(p=args.drop_rate)\n    \n    def _pick_anchors(self, embeds):\n        perm = t.randperm(embeds.shape[0])\n        anchors = perm[:args.anchor]\n        return embeds[anchors]\n    \n    def forward(self, embeds):\n        anchor_embeds = self._pick_anchors(embeds)\n        _anchor_embeds, _ = self.multi_head_attention(anchor_embeds, embeds, embeds)\n        anchor_embeds = _anchor_embeds + anchor_embeds\n        _embeds, _ = self.multi_head_attention(embeds, anchor_embeds, anchor_embeds, need_weights=False)\n        embeds = self.layer_norm1(_embeds + embeds)\n        _embeds = self.fc_dropout(self.dense_layers(embeds))\n        embeds = (self.layer_norm2(_embeds + embeds))\n        return embeds\n\nclass GraphTransformer(nn.Module):\n    def __init__(self):\n        super(GraphTransformer, self).__init__()\n        self.gt_layers = nn.Sequential(*[GTLayer() for i in 
range(args.gt_layer)])\n\n    def forward(self, embeds):\n        for i, layer in enumerate(self.gt_layers):\n            embeds = layer(embeds) / args.scale_layer\n        return embeds\n\nclass Feat_Projector(nn.Module):\n    def __init__(self, feats):\n        super(Feat_Projector, self).__init__()\n\n        if args.proj_method == 'uniform':\n            self.proj_embeds = self.uniform_proj(feats)\n        elif args.proj_method == 'svd' or args.proj_method == 'both':\n            self.proj_embeds = self.svd_proj(feats)\n        elif args.proj_method == 'random':\n            self.proj_embeds = self.random_proj(feats)\n        elif args.proj_method == 'original':\n            self.proj_embeds = feats\n        self.proj_embeds = t.flip(self.proj_embeds, dims=[-1])\n        self.proj_embeds = self.proj_embeds.detach()\n    \n    def svd_proj(self, feats):\n        if args.latdim > feats.shape[0] or args.latdim > feats.shape[1]:\n            dim = min(feats.shape[0], feats.shape[1])\n            decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=dim, niter=args.niter)\n            decom_feats = t.concat([decom_feats, t.zeros([decom_feats.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)\n            s = t.concat([s, t.zeros(args.latdim - dim).to(args.devices[0])])\n        else:\n            decom_feats, s, decom_featdim = t.svd_lowrank(feats, q=args.latdim, niter=args.niter)\n        decom_feats = decom_feats @ t.diag(t.sqrt(s))\n        return decom_feats.cpu()\n    \n    def uniform_proj(self, feats):\n        projection = init(t.empty(args.featdim, args.latdim))\n        return feats @ projection\n    \n    def random_proj(self, feats):\n        projection = init(t.empty(feats.shape[0], args.latdim))\n        return projection\n    \n    def forward(self):\n        return self.proj_embeds\n\nclass Adj_Projector(nn.Module):\n    def __init__(self, adj):\n        super(Adj_Projector, self).__init__()\n\n        if args.proj_method == 'adj_svd' or 
args.proj_method == 'both':\n            self.proj_embeds = self.svd_proj(adj)\n        self.proj_embeds = self.proj_embeds.detach()\n    \n    def svd_proj(self, adj):\n        q = args.latdim\n        if args.latdim > adj.shape[0] or args.latdim > adj.shape[1]:\n            dim = min(adj.shape[0], adj.shape[1])\n            svd_u, s, svd_v = t.svd_lowrank(adj, q=dim, niter=args.niter)\n            svd_u = t.concat([svd_u, t.zeros([svd_u.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)\n            svd_v = t.concat([svd_v, t.zeros([svd_v.shape[0], args.latdim-dim]).to(args.devices[0])], dim=1)\n            s = t.concat([s, t.zeros(args.latdim-dim).to(args.devices[0])])\n        else:\n            svd_u, s, svd_v = t.svd_lowrank(adj, q=q, niter=args.niter)\n        svd_u = svd_u @ t.diag(t.sqrt(s))\n        svd_v = svd_v @ t.diag(t.sqrt(s))\n        if adj.shape[0] != adj.shape[1]:\n            projection = t.concat([svd_u, svd_v], dim=0)\n        else:\n            projection = svd_u + svd_v\n        return projection.cpu()\n    \n    def forward(self):\n        return self.proj_embeds\n\nclass Expert(nn.Module):\n    def __init__(self):\n        super(Expert, self).__init__()\n        \n        self.topo_encoder = TopoEncoder().to(args.devices[0])\n        if args.nn == 'mlp':\n            self.trainable_nn = MLP().to(args.devices[1])\n        else:\n            self.trainable_nn = GraphTransformer().to(args.devices[1])\n        self.trn_count = 1\n    \n    def forward(self, projectors, pck_nodes=None):\n        embeds = projectors.to(args.devices[1])\n        if pck_nodes is not None:\n            embeds = embeds[pck_nodes]\n        embeds = self.trainable_nn(embeds)\n        return embeds\n\n    def pred_norm(self, pos_preds, neg_preds):\n        pos_preds_num = pos_preds.shape[0]\n        neg_preds_shape = neg_preds.shape\n        preds = t.concat([pos_preds, neg_preds.view(-1)])\n        preds = preds - preds.max()\n        pos_preds = 
preds[:pos_preds_num]\n        neg_preds = preds[pos_preds_num:].view(neg_preds_shape)\n        return pos_preds, neg_preds\n    \n    def cal_loss(self, batch_data, projectors):\n        ancs, poss, negs = list(map(lambda x: x.to(args.devices[1]), batch_data))\n        self.trn_count += ancs.shape[0]\n        pck_nodes = t.concat([ancs, poss, negs])\n        final_embeds = self.forward(projectors, pck_nodes)\n        # anc_embeds, pos_embeds, neg_embeds = final_embeds[ancs], final_embeds[poss], final_embeds[negs]\n        anc_embeds, pos_embeds, neg_embeds = t.split(final_embeds, [ancs.shape[0]] * 3)\n        if final_embeds.isinf().any() or final_embeds.isnan().any():\n            raise Exception('Final embedding fails')\n        \n        if args.loss == 'ce':\n            pos_preds, neg_preds = self.pred_norm((anc_embeds * pos_embeds).sum(-1), anc_embeds @ neg_embeds.T)\n            if pos_preds.isinf().any() or pos_preds.isnan().any() or neg_preds.isinf().any() or neg_preds.isnan().any():\n                raise Exception('Preds fails')\n            pos_loss = pos_preds\n            neg_loss = (neg_preds.exp().sum(-1) + pos_preds.exp() + 1e-8).log()\n            pre_loss = -(pos_loss - neg_loss).mean()\n        elif args.loss == 'bpr':\n            pos_preds = (anc_embeds * pos_embeds).sum(-1)\n            neg_preds = (anc_embeds * neg_embeds).sum(-1)\n            pos_loss, neg_loss = pos_preds, neg_preds\n            pre_loss = -((pos_preds - neg_preds).sigmoid() + 1e-10).log().mean() \n\n        if t.isinf(pre_loss).any() or t.isnan(pre_loss).any():\n            raise Exception('NaN or Inf')\n\n        reg_loss = sum(list(map(lambda W: W.norm(2).square() * args.reg, self.parameters())))\n        loss_dict = {'preloss': pre_loss, 'regloss': reg_loss, 'posloss': pos_loss.mean(), 'negloss': neg_loss.mean()}\n        return pre_loss + reg_loss, loss_dict\n    \n    def pred_for_test(self, batch_data, cand_size, projectors, rerun_embed=True):\n        ancs, 
trn_mask = list(map(lambda x: x.to(args.devices[1]), batch_data))\n        if rerun_embed:\n            try:\n                final_embeds = self.forward(projectors)\n            except Exception:\n                final_embeds_list = []\n                div = args.batch * 3\n                temlen = projectors.shape[0] // div\n                for i in range(temlen):\n                    st, ed = div * i, div * (i + 1)\n                    tem_projectors = projectors[st: ed, :]\n                    final_embeds_list.append(self.forward(tem_projectors))\n                if temlen * div < projectors.shape[0]:\n                    tem_projectors = projectors[temlen*div:, :]\n                    final_embeds_list.append(self.forward(tem_projectors))\n                final_embeds = t.concat(final_embeds_list, dim=0)\n            self.final_embeds = final_embeds\n        final_embeds = self.final_embeds\n        anc_embeds = final_embeds[ancs]\n        cand_embeds = final_embeds[-cand_size:]\n\n        mask_mat = t.sparse.FloatTensor(trn_mask, t.ones(trn_mask.shape[1]).to(args.devices[1]), t.Size([ancs.shape[0], cand_size]))\n        dense_mat = mask_mat.to_dense()\n        all_preds = anc_embeds @ cand_embeds.T * (1 - dense_mat) - dense_mat * 1e8\n        return all_preds\n    \n    def pred_for_node_test(self, nodes, cand_size, feats, rerun_embed=True):\n        if rerun_embed:\n            final_embeds = self.forward(feats)\n            self.final_embeds = final_embeds\n        final_embeds = self.final_embeds\n        anc_embeds = final_embeds[nodes]\n        class_embeds = final_embeds[-cand_size:]\n        preds = anc_embeds @ class_embeds.T\n        return t.argmax(preds, dim=-1)\n\n    def attempt(self, topo_embeds, dataset):\n        final_embeds = self.trainable_nn(topo_embeds)\n        rows, cols, negs = list(map(lambda x: t.from_numpy(x).long().to(args.devices[1]), [dataset.ancs, dataset.poss, dataset.negs]))\n        if rows.shape[0] > args.attempt_cache:\n   
          random_perm = t.randperm(rows.shape[0], device=rows.device)\n            pck_perm = random_perm[:args.attempt_cache]\n            rows = rows[pck_perm]\n            cols = cols[pck_perm]\n            negs = negs[pck_perm]\n        while True:\n            try:\n                row_embeds = final_embeds[rows]\n                col_embeds = final_embeds[cols]\n                neg_embeds = final_embeds[negs]\n                score = ((row_embeds * col_embeds).sum(-1) - (row_embeds * neg_embeds).sum(-1)).sigmoid().mean().item()\n                break\n            except RuntimeError:  # OOM: halve the sample budget and retry\n                args.attempt_cache = args.attempt_cache // 2\n                random_perm = t.randperm(rows.shape[0], device=rows.device)\n                pck_perm = random_perm[:args.attempt_cache]\n                rows = rows[pck_perm]\n                cols = cols[pck_perm]\n                negs = negs[pck_perm]\n        t.cuda.empty_cache()\n        return score\n\nclass AnyGraph(nn.Module):\n    def __init__(self):\n        super(AnyGraph, self).__init__()\n        self.experts = nn.ModuleList([Expert() for _ in range(args.expert_num)]).cuda()\n        self.opts = list(map(lambda expert: t.optim.Adam(expert.parameters(), lr=args.lr, weight_decay=0), self.experts))\n\n    def one_graph_one_expert(self, handlers, log_assignment=False):\n        self.assignment = [i for i in range(len(handlers))]\n        if log_assignment:\n            print('\\n----------\\nAssignment')\n            for dataset_id, handler in enumerate(handlers):\n                print(handler.data_name, f'{self.assignment[dataset_id]}')\n    \n    def store_history_assignment(self, assignment):\n        pass\n        \n    def assign_experts(self, handlers, reca=True, log_assignment=False):\n        if args.expert_num == 1:\n            self.assignment = [0] * len(handlers)\n            return\n        if args.assignment == 'one-graph-one-expert':\n            self.one_graph_one_expert(handlers, 
log_assignment)\n            return\n        if args.assignment not in ['top1', 'top2']:\n            raise ValueError('Unrecognized assignment method')\n        try:\n            # recalibrate scores to favor experts that have been trained less often\n            expert_scores = np.array(list(map(lambda expert: expert.trn_count, self.experts)))\n            expert_scores = (1.0 - expert_scores / np.sum(expert_scores)) * args.reca_range + 1.0 - args.reca_range / 2\n        except AttributeError:  # experts carry no trn_count before the first training round\n            expert_scores = np.ones(len(self.experts))\n        with t.no_grad():\n            assignment = [list() for i in range(len(handlers))]\n            for dataset_id, handler in enumerate(handlers):\n                topo_embeds = handler.projectors.to(args.devices[1])\n                for expert_id, expert in enumerate(self.experts):\n                    expert = expert.to(args.devices[1])\n                    score = expert.attempt(topo_embeds, handler.trn_loader.dataset)\n                    if reca:\n                        score *= expert_scores[expert_id]\n                    assignment[dataset_id].append((expert_id, score))\n                assignment[dataset_id].sort(key=lambda x: x[1], reverse=True)\n                if args.assignment == 'top2':\n                    # randomly swap the top-2 experts when their scores are close\n                    if assignment[dataset_id][0][1] - assignment[dataset_id][1][1] < 0.05 and np.random.uniform() > 0.5:\n                        assignment[dataset_id][0], assignment[dataset_id][1] = assignment[dataset_id][1], assignment[dataset_id][0]\n            if log_assignment:\n                print('\\n----------\\nAssignment')\n                for dataset_id, handler in enumerate(handlers):\n                    out = ''\n                    for exp_idx in range(min(4, len(self.experts))):\n                        out += f'({assignment[dataset_id][exp_idx][0]}, {assignment[dataset_id][exp_idx][1]}) '\n                    print(handler.data_name, out)\n                print('----------\\n')\n\n            self.assignment = 
list(map(lambda x: x[0][0], assignment))\n    \n    def summon(self, dataset_id):\n        return self.experts[self.assignment[dataset_id]]\n    \n    def summon_opt(self, dataset_id):\n        return self.opts[self.assignment[dataset_id]]\n"
  },
  {
    "path": "node_classification/params.py",
    "content": "import argparse\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Model Parameters')\n    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')\n    parser.add_argument('--batch', default=4096, type=int, help='training batch size')\n    parser.add_argument('--tst_batch', default=256, type=int, help='testing batch size (number of users)')\n    parser.add_argument('--epoch', default=100, type=int, help='number of epochs')\n    parser.add_argument('--save_path', default='tem', help='file name to save model and training record')\n    parser.add_argument('--load_model', default=None, help='model name to load')\n    parser.add_argument('--data', default='ml1m', type=str, help='name of dataset')\n    parser.add_argument('--tst_epoch', default=1, type=int, help='number of epoch to test while training')\n    parser.add_argument('--gpu', default='0', type=str, help='indicates which gpu to use')\n    parser.add_argument('--topk', default=20, type=int, help='topk in evaluation')\n    parser.add_argument('--epoch_max_step', default=-1, type=int, help='indicates the maximum number of steps in one epoch, -1 denotes full steps')\n    parser.add_argument('--trn_mode', default='train-all', type=str, help='[fewshot, train-all]')\n    parser.add_argument('--tst_mode', default='tst', type=str, help='[tst, val]')\n    parser.add_argument('--eval_loss', default=True, type=bool, help='whether use CE loss to evaluate test performance')\n    parser.add_argument('--ratio_fewshot_set', default=0.5, type=float, help='ratio of fewshot set')\n    parser.add_argument('--shot', default=5, type=int, help='number of shots for each node')\n    parser.add_argument('--tst_steps', default=-1, type=int, help='number of test steps, -1 indicates all')\n\n    parser.add_argument('--reg', default=1e-7, type=float, help='weight decay regularizer')\n    parser.add_argument('--latdim', default=512, type=int, help='latent dimensionality')\n    
parser.add_argument('--gnn_layer', default=3, type=int, help='number of gnn layers')\n    parser.add_argument('--fc_layer', default=8, type=int, help='number of fully-connected layers')\n    parser.add_argument('--gt_layer', default=2, type=int, help='number of graph transformer layers')\n    parser.add_argument('--head', default=4, type=int, help='number of attention heads')\n    parser.add_argument('--anchor', default=256, type=int, help='number of anchor nodes in the compressed graph transformer')\n    parser.add_argument('--act', default='relu', type=str, help='activation function')\n    parser.add_argument('--dataset_setting', default='training', type=str, help='which set of datasets to use')\n    parser.add_argument('--assignment', default='top1', type=str, help='expert assignment method: one-graph-one-expert, top1, or top2')\n    parser.add_argument('--scale_layer', default=10, type=float, help='per-layer scale factor')\n    parser.add_argument('--leaky', default=0.5, type=float, help='slope of leaky relu activation')\n    parser.add_argument('--drop_rate', default=0.1, type=float, help='ratio of dropout')\n    parser.add_argument('--reca_range', default=0.2, type=float, help='range of recalibration')\n    parser.add_argument('--selfloop', default=0, type=int, help='whether to add self-loops')\n    parser.add_argument('--niter', default=2, type=int, help='number of iterations in svd')\n    parser.add_argument('--expert_num', default=8, type=int, help='number of experts')\n    parser.add_argument('--loss', default='ce', type=str, help='loss function')\n    parser.add_argument('--proj_method', default='both', type=str, help='feature projection method')\n    parser.add_argument('--nn', default='mlp', type=str, help='what trainable network to use')\n    parser.add_argument('--proj_trn_steps', default=100, type=int, help='number of training steps for one initial projection')\n    parser.add_argument('--attempt_cache', default=10000000, type=int, help='maximum number of sampled training triplets used in one expert-assignment attempt')\n    return parser.parse_args()\nargs = parse_args()"
  },
  {
    "path": "params.py",
    "content": "import argparse\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Model Parameters')\n    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')\n    parser.add_argument('--batch', default=4096, type=int, help='training batch size')\n    parser.add_argument('--tst_batch', default=256, type=int, help='testing batch size (number of users)')\n    parser.add_argument('--epoch', default=100, type=int, help='number of epochs')\n    parser.add_argument('--save_path', default='tem', help='file name to save model and training record')\n    parser.add_argument('--load_model', default=None, help='model name to load')\n    parser.add_argument('--tst_epoch', default=1, type=int, help='number of epoch to test while training')\n    parser.add_argument('--gpu', default='0', type=str, help='indicates which gpu to use')\n    parser.add_argument('--topk', default=20, type=int, help='topk in evaluation')\n    parser.add_argument('--epoch_max_step', default=-1, type=int, help='indicates the maximum number of steps in one epoch, -1 denotes full steps')\n    parser.add_argument('--trn_mode', default='train-all', type=str, help='[fewshot, train-all]')\n    parser.add_argument('--tst_mode', default='tst', type=str, help='[tst, val]')\n    parser.add_argument('--eval_loss', default=True, type=bool, help='whether use CE loss to evaluate test performance')\n    parser.add_argument('--ratio_fewshot', default=0.1, type=float, help='ratio of fewshot set')\n    parser.add_argument('--tst_steps', default=-1, type=int, help='number of test steps, -1 indicates all')\n\n    parser.add_argument('--reg', default=1e-7, type=float, help='weight decay regularizer')\n    parser.add_argument('--latdim', default=512, type=int, help='latent dimensionality')\n    parser.add_argument('--gnn_layer', default=3, type=int, help='number of gnn layers')\n    parser.add_argument('--fc_layer', default=8, type=int, help='number of fully-connected layers')\n    
parser.add_argument('--gt_layer', default=2, type=int, help='number of graph transformer layers')\n    parser.add_argument('--head', default=4, type=int, help='number of attention heads')\n    parser.add_argument('--anchor', default=256, type=int, help='number of anchor nodes in the compressed graph transformer')\n    parser.add_argument('--act', default='relu', type=str, help='activation function')\n    parser.add_argument('--dataset_setting', default='training', type=str, help='which set of datasets to use')\n    parser.add_argument('--scale_layer', default=10, type=float, help='per-layer scale factor')\n    parser.add_argument('--leaky', default=0.5, type=float, help='slope of leaky relu activation')\n    parser.add_argument('--drop_rate', default=0.1, type=float, help='ratio of dropout')\n    parser.add_argument('--reca_range', default=0.2, type=float, help='range of recalibration')\n    parser.add_argument('--selfloop', default=0, type=int, help='whether to add self-loops')\n    parser.add_argument('--niter', default=2, type=int, help='number of iterations in svd')\n    parser.add_argument('--expert_num', default=8, type=int, help='number of experts')\n    parser.add_argument('--loss', default='ce', type=str, help='loss function')\n    parser.add_argument('--proj_method', default='both', type=str, help='feature projection method')\n    parser.add_argument('--nn', default='mlp', type=str, help='what trainable network to use')\n    parser.add_argument('--proj_trn_steps', default=100, type=int, help='number of training steps for one initial projection')\n    parser.add_argument('--attempt_cache', default=10000000, type=int, help='maximum number of sampled training triplets used in one expert-assignment attempt')\n    return parser.parse_args()\nargs = parse_args()"
  }
]