Repository: xxlya/BrainGNN_Pytorch
Branch: main
Commit: 1e337e7a13af
Files: 24
Total size: 112.3 KB

Directory structure:
gitextract_qan2yim0/
├── .idea/
│   ├── .gitignore
│   ├── GNN_biomarker_MEDIA.iml
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── inspectionProfiles/
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── webServers.xml
├── 01-fetch_data.py
├── 02-process_data.py
├── 03-main.py
├── README.md
├── data/
│   └── subject_ID.txt
├── imports/
│   ├── ABIDEDataset.py
│   ├── __inits__.py
│   ├── gdc.py
│   ├── preprocess_data.py
│   ├── read_abide_stats_parall.py
│   └── utils.py
├── net/
│   ├── braingnn.py
│   ├── braingraphconv.py
│   ├── brainmsgpassing.py
│   └── inits.py
└── requirements.txt

================================================
FILE CONTENTS
================================================

================================================
FILE: .idea/.gitignore
================================================
# Default ignored files
/shelf/
/workspace.xml
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/
================================================
FILE: 01-fetch_data.py
================================================
# Copyright (c) 2019 Mwiza Kunda
# Copyright (C) 2017 Sarah Parisot, Sofia Ira Ktena
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

'''
This script mainly refers to
https://github.com/kundaMwiza/fMRI-site-adaptation/blob/master/fetch_data.py
'''

from nilearn import datasets
import argparse
from imports import preprocess_data as Reader
import os
import shutil
import sys

# Input data variables
code_folder = os.getcwd()
root_folder = '/data/'
data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal/')
if not os.path.exists(data_folder):
    os.makedirs(data_folder)
shutil.copyfile(os.path.join(root_folder, 'subject_ID.txt'), os.path.join(data_folder, 'subject_IDs.txt'))


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def main():
    parser = argparse.ArgumentParser(description='Download ABIDE data and compute functional connectivity matrices')
    parser.add_argument('--pipeline', default='cpac', type=str,
                        help='Pipeline to preprocess ABIDE data. Available options are ccs, cpac, dparsf and niak. default: cpac.')
    parser.add_argument('--atlas', default='cc200',
                        help='Brain parcellation atlas. Options: ho, cc200 and cc400, default: cc200.')
    parser.add_argument('--download', default=True, type=str2bool,
                        help='Download data or just compute functional connectivity. default: True')
    args = parser.parse_args()
    print(args)

    params = dict()
    pipeline = args.pipeline
    atlas = args.atlas
    download = args.download

    # Files to fetch
    files = ['rois_' + atlas]
    filemapping = {'func_preproc': 'func_preproc.nii.gz',
                   files[0]: files[0] + '.1D'}

    # Download database files
    if download:
        abide = datasets.fetch_abide_pcp(data_dir=root_folder, pipeline=pipeline,
                                         band_pass_filtering=True, global_signal_regression=False,
                                         derivatives=files, quality_checked=False)

    subject_IDs = Reader.get_ids()  # changed path to data path
    subject_IDs = subject_IDs.tolist()

    # Create a folder for each subject
    for s, fname in zip(subject_IDs, Reader.fetch_filenames(subject_IDs, files[0], atlas)):
        subject_folder = os.path.join(data_folder, s)
        if not os.path.exists(subject_folder):
            os.mkdir(subject_folder)

        # Get the base filename for each subject
        base = fname.split(files[0])[0]

        # Move each subject file to the subject folder
        for fl in files:
            if not os.path.exists(os.path.join(subject_folder, base + filemapping[fl])):
                shutil.move(base + filemapping[fl], subject_folder)

    time_series = Reader.get_timeseries(subject_IDs, atlas)

    # Compute and save connectivity matrices
    Reader.subject_connectivity(time_series, subject_IDs, atlas, 'correlation')
    Reader.subject_connectivity(time_series, subject_IDs, atlas, 'partial correlation')


if __name__ == '__main__':
    main()
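A note on what the two `subject_connectivity` calls above produce: plain Pearson and partial-correlation matrices over each subject's ROI time series, computed with nilearn. A minimal, self-contained sketch of the same correlation step (the shapes are illustrative, not repo output):

```python
import numpy as np
from nilearn.connectome import ConnectivityMeasure

ts = np.random.randn(196, 200)                     # one subject: timepoints x cc200 ROIs
mat = ConnectivityMeasure(kind='correlation').fit_transform([ts])[0]
print(mat.shape)                                   # (200, 200), symmetric, ones on the diagonal
```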
================================================
FILE: 02-process_data.py
================================================
# Copyright (c) 2019 Mwiza Kunda
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import sys
import argparse
import pandas as pd
import numpy as np
from imports import preprocess_data as Reader
import deepdish as dd
import warnings
import os

warnings.filterwarnings("ignore")

root_folder = '/data/'
data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal/')


# Process boolean command line arguments
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def main():
    parser = argparse.ArgumentParser(description='Classification of the ABIDE dataset using a Ridge classifier. '
                                                 'MIDA is used to minimize the distribution mismatch between ABIDE sites')
    parser.add_argument('--atlas', default='cc200',
                        help='Atlas for network construction (node definition) options: ho, cc200, cc400, default: cc200.')
    parser.add_argument('--seed', default=123, type=int, help='Seed for random initialisation. default: 123.')
    parser.add_argument('--nclass', default=2, type=int, help='Number of classes. default: 2')
    args = parser.parse_args()
    print('Arguments: \n', args)

    params = dict()
    params['seed'] = args.seed  # seed for random initialisation

    # Algorithm choice
    params['atlas'] = args.atlas  # Atlas for network construction
    atlas = args.atlas  # Atlas for network construction (node definition)

    # Get subject IDs and class labels
    subject_IDs = Reader.get_ids()
    labels = Reader.get_subject_score(subject_IDs, score='DX_GROUP')

    # Number of subjects and classes for binary classification
    num_classes = args.nclass
    num_subjects = len(subject_IDs)
    params['n_subjects'] = num_subjects

    # Initialise variables for class labels and acquisition sites
    # 1 is autism, 2 is control
    y_data = np.zeros([num_subjects, num_classes])  # n x 2
    y = np.zeros([num_subjects, 1])  # n x 1

    # Get class labels for all subjects
    for i in range(num_subjects):
        y_data[i, int(labels[subject_IDs[i]]) - 1] = 1
        y[i] = int(labels[subject_IDs[i]])

    # Compute feature vectors (vectorised connectivity networks)
    fea_corr = Reader.get_networks(subject_IDs, iter_no='', kind='correlation', atlas_name=atlas)  # (1035, 200, 200)
    fea_pcorr = Reader.get_networks(subject_IDs, iter_no='', kind='partial correlation', atlas_name=atlas)  # (1035, 200, 200)

    if not os.path.exists(os.path.join(data_folder, 'raw')):
        os.makedirs(os.path.join(data_folder, 'raw'))
    for i, subject in enumerate(subject_IDs):
        dd.io.save(os.path.join(data_folder, 'raw', subject + '.h5'),
                   {'corr': fea_corr[i], 'pcorr': fea_pcorr[i], 'label': y[i] % 2})


if __name__ == '__main__':
    main()
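Each subject ends up as one HDF5 record under `raw/`, which `imports/read_abide_stats_parall.py` later turns into graphs. A quick way to inspect a record (the subject ID and path below are placeholders):

```python
import deepdish as dd

rec = dd.io.load('/data/ABIDE_pcp/cpac/filt_noglobal/raw/50003.h5')
# 'corr' feeds the node features, 'pcorr' defines the edges; 'label' is DX_GROUP % 2,
# so autism (DX_GROUP=1) -> 1 and control (DX_GROUP=2) -> 0
print(rec['corr'].shape, rec['pcorr'].shape, rec['label'])
```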
================================================
FILE: 03-main.py
================================================
import os
import numpy as np
import argparse
import time
import copy
import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
from tensorboardX import SummaryWriter
from imports.ABIDEDataset import ABIDEDataset
from torch_geometric.data import DataLoader
from net.braingnn import Network
from imports.utils import train_val_test_split
from sklearn.metrics import classification_report, confusion_matrix

torch.manual_seed(123)

EPS = 1e-10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training')
parser.add_argument('--batchSize', type=int, default=100, help='size of the batches')
parser.add_argument('--dataroot', type=str, default='/home/azureuser/projects/BrainGNN/data/ABIDE_pcp/cpac/filt_noglobal', help='root directory of the dataset')
parser.add_argument('--fold', type=int, default=0, help='training which fold')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--stepsize', type=int, default=20, help='scheduler step size')
parser.add_argument('--gamma', type=float, default=0.5, help='scheduler shrinking rate')
parser.add_argument('--weightdecay', type=float, default=5e-3, help='regularization')
parser.add_argument('--lamb0', type=float, default=1, help='classification loss weight')
parser.add_argument('--lamb1', type=float, default=0, help='s1 unit regularization')
parser.add_argument('--lamb2', type=float, default=0, help='s2 unit regularization')
parser.add_argument('--lamb3', type=float, default=0.1, help='s1 entropy regularization')
parser.add_argument('--lamb4', type=float, default=0.1, help='s2 entropy regularization')
parser.add_argument('--lamb5', type=float, default=0.1, help='s1 consistency regularization')
parser.add_argument('--layer', type=int, default=2, help='number of GNN layers')
parser.add_argument('--ratio', type=float, default=0.5, help='pooling ratio')
parser.add_argument('--indim', type=int, default=200, help='feature dim')
parser.add_argument('--nroi', type=int, default=200, help='num of ROIs')
parser.add_argument('--nclass', type=int, default=2, help='num of classes')
# note: argparse's type=bool maps any non-empty string to True; leave these flags unset to keep the defaults
parser.add_argument('--load_model', type=bool, default=False)
parser.add_argument('--save_model', type=bool, default=True)
parser.add_argument('--optim', type=str, default='Adam', help='optimization method: SGD, Adam')
parser.add_argument('--save_path', type=str, default='./model/', help='path to save model')
opt = parser.parse_args()

if not os.path.exists(opt.save_path):
    os.makedirs(opt.save_path)

#################### Parameter Initialization #######################
path = opt.dataroot
name = 'ABIDE'
save_model = opt.save_model
load_model = opt.load_model
opt_method = opt.optim
num_epoch = opt.n_epochs
fold = opt.fold
writer = SummaryWriter(os.path.join('./log', str(fold)))

################## Define Dataloader ##################################
dataset = ABIDEDataset(path, name)
dataset.data.y = dataset.data.y.squeeze()
dataset.data.x[dataset.data.x == float('inf')] = 0

tr_index, val_index, te_index = train_val_test_split(fold=fold)
train_dataset = dataset[tr_index]
val_dataset = dataset[val_index]
test_dataset = dataset[te_index]

train_loader = DataLoader(train_dataset, batch_size=opt.batchSize, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=opt.batchSize, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=opt.batchSize, shuffle=False)

############### Define Graph Deep Learning Network ##########################
model = Network(opt.indim, opt.ratio, opt.nclass).to(device)
print(model)

if opt_method == 'Adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weightdecay)
elif opt_method == 'SGD':
    optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9, weight_decay=opt.weightdecay, nesterov=True)

scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.stepsize, gamma=opt.gamma)


############################### Define Other Loss Functions ########################################
def topk_loss(s, ratio):
    if ratio > 0.5:
        ratio = 1 - ratio
    s = s.sort(dim=1).values
    res = -torch.log(s[:, -int(s.size(1) * ratio):] + EPS).mean() \
          - torch.log(1 - s[:, :int(s.size(1) * ratio)] + EPS).mean()
    return res


def consist_loss(s):
    if len(s) == 0:
        return 0
    s = torch.sigmoid(s)
    W = torch.ones(s.shape[0], s.shape[0])
    D = torch.eye(s.shape[0]) * torch.sum(W, dim=1)
    L = D - W
    L = L.to(device)
    res = torch.trace(torch.transpose(s, 0, 1) @ L @ s) / (s.shape[0] * s.shape[0])
    return res
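# What these two terms reward (a toy check, not part of training):
#   s = torch.tensor([[0.9, 0.8, 0.2, 0.1]])
#   topk_loss(s, 0.5)                        # ~0.33: top half near 1, bottom half near 0
#   topk_loss(torch.full((1, 4), 0.5), 0.5)  # ~1.39: uninformative scores are penalised
# consist_loss is the Laplacian quadratic form of the all-ones graph: with W = 11^T,
# trace(s^T L s) = 0.5 * sum_{i,j} ||s_i - s_j||^2, so it pulls the (sigmoided) score
# vectors of same-class subjects toward each other.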
###################### Network Training Function #####################################
def train(epoch):
    print('train...........')
    scheduler.step()  # StepLR stepped once per epoch, at the top of each epoch
    for param_group in optimizer.param_groups:
        print("LR", param_group['lr'])
    model.train()
    s1_list = []
    s2_list = []
    loss_all = 0
    step = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, w1, w2, s1, s2 = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos)
        s1_list.append(s1.view(-1).detach().cpu().numpy())
        s2_list.append(s2.view(-1).detach().cpu().numpy())

        loss_c = F.nll_loss(output, data.y)
        loss_p1 = (torch.norm(w1, p=2) - 1) ** 2
        loss_p2 = (torch.norm(w2, p=2) - 1) ** 2
        loss_tpk1 = topk_loss(s1, opt.ratio)
        loss_tpk2 = topk_loss(s2, opt.ratio)
        loss_consist = 0
        for c in range(opt.nclass):
            loss_consist += consist_loss(s1[data.y == c])
        loss = opt.lamb0 * loss_c + opt.lamb1 * loss_p1 + opt.lamb2 * loss_p2 \
               + opt.lamb3 * loss_tpk1 + opt.lamb4 * loss_tpk2 + opt.lamb5 * loss_consist
        writer.add_scalar('train/classification_loss', loss_c, epoch * len(train_loader) + step)
        writer.add_scalar('train/unit_loss1', loss_p1, epoch * len(train_loader) + step)
        writer.add_scalar('train/unit_loss2', loss_p2, epoch * len(train_loader) + step)
        writer.add_scalar('train/TopK_loss1', loss_tpk1, epoch * len(train_loader) + step)
        writer.add_scalar('train/TopK_loss2', loss_tpk2, epoch * len(train_loader) + step)
        writer.add_scalar('train/GCL_loss', loss_consist, epoch * len(train_loader) + step)
        step = step + 1

        loss.backward()
        loss_all += loss.item() * data.num_graphs
        optimizer.step()

    s1_arr = np.hstack(s1_list)
    s2_arr = np.hstack(s2_list)
    return loss_all / len(train_dataset), s1_arr, s2_arr, w1, w2


###################### Network Testing Function #####################################
def test_acc(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        outputs = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos)
        pred = outputs[0].max(dim=1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)


def test_loss(loader, epoch):
    print('testing...........')
    model.eval()
    loss_all = 0
    for data in loader:
        data = data.to(device)
        output, w1, w2, s1, s2 = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos)
        loss_c = F.nll_loss(output, data.y)
        loss_p1 = (torch.norm(w1, p=2) - 1) ** 2
        loss_p2 = (torch.norm(w2, p=2) - 1) ** 2
        loss_tpk1 = topk_loss(s1, opt.ratio)
        loss_tpk2 = topk_loss(s2, opt.ratio)
        loss_consist = 0
        for c in range(opt.nclass):
            loss_consist += consist_loss(s1[data.y == c])
        loss = opt.lamb0 * loss_c + opt.lamb1 * loss_p1 + opt.lamb2 * loss_p2 \
               + opt.lamb3 * loss_tpk1 + opt.lamb4 * loss_tpk2 + opt.lamb5 * loss_consist
        loss_all += loss.item() * data.num_graphs
    return loss_all / len(loader.dataset)
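# Note: both evaluation helpers above run the forward pass with autograd enabled;
# wrapping their loops in `with torch.no_grad():` would be a behaviour-preserving
# way to cut memory during evaluation (a suggestion, not in the original code).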
#######################################################################################
############################   Model Training #########################################
#######################################################################################
best_model_wts = copy.deepcopy(model.state_dict())
best_loss = 1e10
for epoch in range(0, num_epoch):
    since = time.time()
    tr_loss, s1_arr, s2_arr, w1, w2 = train(epoch)
    tr_acc = test_acc(train_loader)
    val_acc = test_acc(val_loader)
    val_loss = test_loss(val_loader, epoch)
    time_elapsed = time.time() - since
    print('*====**')
    print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Epoch: {:03d}, Train Loss: {:.7f}, '
          'Train Acc: {:.7f}, Val Loss: {:.7f}, Val Acc: {:.7f}'.format(epoch, tr_loss, tr_acc, val_loss, val_acc))

    writer.add_scalars('Acc', {'train_acc': tr_acc, 'val_acc': val_acc}, epoch)
    writer.add_scalars('Loss', {'train_loss': tr_loss, 'val_loss': val_loss}, epoch)
    writer.add_histogram('Hist/hist_s1', s1_arr, epoch)
    writer.add_histogram('Hist/hist_s2', s2_arr, epoch)

    if val_loss < best_loss and epoch > 5:
        print("saving best model")
        best_loss = val_loss
        best_model_wts = copy.deepcopy(model.state_dict())
        if save_model:
            torch.save(best_model_wts, os.path.join(opt.save_path, str(fold) + '.pth'))

#######################################################################################
######################### Testing on testing set ######################################
#######################################################################################
if opt.load_model:
    model = Network(opt.indim, opt.ratio, opt.nclass).to(device)
    model.load_state_dict(torch.load(os.path.join(opt.save_path, str(fold) + '.pth')))
    model.eval()
    preds = []
    correct = 0
    # note: this branch reports on the validation split, not the held-out test split
    for data in val_loader:
        data = data.to(device)
        outputs = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos)
        pred = outputs[0].max(1)[1]
        preds.append(pred.cpu().detach().numpy())
        correct += pred.eq(data.y).sum().item()
    preds = np.concatenate(preds, axis=0)
    trues = val_dataset.data.y.cpu().detach().numpy()
    cm = confusion_matrix(trues, preds)
    print("Confusion matrix")
    print(cm)
    print(classification_report(trues, preds))
else:
    model.load_state_dict(best_model_wts)
    model.eval()
    test_accuracy = test_acc(test_loader)
    test_l = test_loss(test_loader, 0)
    print("===========================")
    print("Test Acc: {:.7f}, Test Loss: {:.7f} ".format(test_accuracy, test_l))

print(opt)


================================================
FILE: README.md
================================================
# Graph Neural Network for Brain Network Analysis

A preliminary implementation of BrainGNN. The example presented here is on the public resting-state fMRI ABIDE dataset, for convenience of development. This dataset is different from the ones used in our publication, which are cleaner task-fMRI. We are still seeking solutions to improve representation learning on the noisy data.

## Usage

### Setup

**pip**

See `requirements.txt` for the environment configuration.

```bash
pip install -r requirements.txt
```

**PYG**

To install the pyg library, [please refer to the documentation](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html).

### Dataset

**ABIDE**

We treat each fMRI as a brain graph. How to download and construct the graphs?

```
python 01-fetch_data.py
python 02-process_data.py
```

### How to run classification?

Training and testing are integrated in `03-main.py`. To run:

```
python 03-main.py
```

## Citation

If you find the code and dataset useful, please cite our paper.
```latex @article{li2020braingnn, title={Braingnn: Interpretable brain graph neural network for fmri analysis}, author={Li, Xiaoxiao and Zhou,Yuan and Dvornek, Nicha and Zhang, Muhan and Gao, Siyuan and Zhuang, Juntang and Scheinost, Dustin and Staib, Lawrence and Ventola, Pamela and Duncan, James}, journal={bioRxiv}, year={2020}, publisher={Cold Spring Harbor Laboratory} } ``` ================================================ FILE: data/subject_ID.txt ================================================ 50128 51203 50325 50117 50573 50741 50779 51009 50746 50574 50110 50322 51036 51204 50119 50126 50314 51490 50784 51464 51000 51038 50748 51235 51007 51463 50783 50777 50313 50121 51053 51261 50723 50511 51295 50347 50982 50976 51098 51292 50340 50516 50724 51266 51054 50186 50529 50985 50520 50376 50978 50144 51096 50382 51250 51062 50349 51065 50385 51257 50143 51091 50371 50527 51268 50188 50518 50749 51039 50776 50120 50312 51006 51234 50782 51462 50118 51465 50785 51001 50315 50127 51491 51008 50778 51205 50575 50747 50111 50129 50116 50324 50740 50572 51030 51202 50370 50142 51090 50526 51256 51064 50519 50189 51269 51063 50383 51251 50521 50145 51097 50979 50377 50348 51055 50187 51267 51293 50341 50725 51258 50984 50528 50970 50510 50722 51294 50346 51260 51052 51099 50977 50379 50983 50039 50496 51312 50234 50006 50650 50802 50668 51118 50657 50233 51127 51315 50491 50008 50498 50037 50205 50661 51581 50453 50695 51575 51111 51323 51129 50659 51324 51116 51572 50692 50666 50202 50030 51142 51370 50269 51189 50251 50407 50438 51348 50603 50267 51187 50055 51341 50293 51173 51174 51346 50294 51180 50052 50260 50604 50436 50658 51128 50667 50455 50031 50203 51117 51325 50693 51573 50499 50009 51574 50694 51322 51110 50204 50036 51580 50660 50803 50669 51314 51126 50490 50656 50232 50038 50804 50007 50235 50651 50463 50497 51121 51313 50261 51181 50053 50437 50605 51347 50295 51175 50408 51172 51340 50292 50602 51186 50054 50266 50259 50250 50406 51349 50439 50257 51188 50268 51195 50047 50275 50611 51161 51353 50281 51159 51354 50286 51166 50424 50616 50272 51192 50040 50049 51362 51150 50412 50620 50618 50288 51168 50627 50415 50243 51365 50441 50217 50025 50819 51331 51103 51567 50687 50826 51558 51560 51104 51336 50022 50210 50446 51309 50821 51132 51300 51556 50642 50470 50014 51569 50689 50817 50013 50477 50645 50483 51307 51135 50448 51338 51169 50289 50619 51364 51156 50414 50626 50242 50048 50245 50621 50413 51151 51363 50628 50617 50425 51193 50041 50273 51167 51355 50287 51352 50280 51160 50274 51194 50046 50422 50610 50482 51134 51306 50012 50644 51339 50449 50643 50015 51301 51133 50485 51557 50816 50688 51568 50211 50023 50447 51561 51105 50820 51308 51102 51330 50686 51566 50440 50818 50024 50216 51559 50169 50955 50156 51084 50364 50700 50532 51070 50390 51048 50952 50738 50397 51077 50999 50707 50363 51083 50990 50158 50964 51273 51041 50193 50355 50167 50503 50731 50709 50399 51079 50997 50736 50504 50160 51280 50352 51046 50194 51274 51482 50306 50134 51220 51012 51476 50796 50339 50791 51471 51015 51227 50133 50301 50557 51485 51218 50568 51023 51211 50753 50561 50105 50337 51478 50798 50308 50330 50102 50566 50754 51216 51024 50559 51229 50996 51078 50962 50708 51275 51047 50195 50505 50737 51281 50353 50161 50965 50159 50991 50166 50354 50730 50502 51040 50192 51272 50739 51049 50706 50150 51082 50362 50998 51076 50954 50168 50391 51071 50365 50157 51085 50701 51025 51217 50103 50331 50755 50567 51228 50558 50560 50752 50336 50104 51210 50799 51479 50300 50132 50556 
51484 51470 50790 51226 51014 50569 51219 51013 51221 50797 51477 50551 51483 50135 50307 50338 50171 50343 51291 50727 50515 50185 51057 51265 50972 50388 50986 51068 51262 50182 51050 51606 50344 51296 50981 50149 51254 50386 50988 51066 50372 50524 51059 50711 50523 51095 50147 50375 51061 51253 50381 51298 51238 50577 50745 50321 50113 51207 51035 51469 50789 50319 51456 51032 50114 50326 50742 50570 51209 51236 50780 51460 50774 50122 50310 51458 50317 50125 51493 50773 51467 50787 51231 51003 51252 50380 51060 50710 50374 51094 50146 51299 51093 50373 50525 51067 50989 51255 50387 50728 51058 50345 51297 50183 51051 51263 51607 50148 50974 51264 50184 51056 50342 50170 50514 50726 51069 50987 50973 51459 50329 50786 51466 51002 51230 50124 50316 50772 51492 50578 51208 50775 50311 50123 51237 51461 50781 50318 50788 51468 50327 50115 50571 50743 51457 51201 51033 51239 51034 51206 50744 50576 50112 50320 50060 50252 50404 51146 50609 50299 51179 51373 51141 50403 50255 50058 50297 51345 51177 50263 50051 51183 50435 50607 51148 50056 51184 50264 51170 50290 51342 50801 51329 50466 50654 51316 51124 50492 51578 50698 50208 51123 51311 50005 50237 50653 51318 50468 51327 50691 51571 50665 51585 50033 50201 50239 50206 50034 51582 51576 50696 51320 51112 50291 51343 51171 50433 50601 50265 50057 51185 50050 51182 50262 50606 50434 50296 51344 51149 50402 50254 51140 50059 51147 50253 50405 51178 50298 50608 50697 51577 51113 51321 50035 50207 50663 51583 50469 51319 51584 50664 50200 50032 51326 51114 51570 50690 50807 50209 50699 51579 50236 50004 50652 50494 51122 51328 50800 51317 50493 50655 50467 50003 51563 50683 51335 51107 50213 50445 51138 50648 50822 50442 50026 50214 51100 51332 51564 50019 50825 50489 50010 50646 50480 51136 51304 51109 51303 51131 50487 50017 50028 50814 50418 51165 50285 51357 50615 50427 50043 51191 50271 50249 50276 50044 51196 50612 50282 51350 51162 51359 50416 50624 50240 51154 50278 51198 51153 51361 50247 50623 50411 50016 51130 51302 50486 50815 50029 50481 51305 51137 50011 50647 50812 51333 51101 51565 50685 50443 50215 50027 50488 50824 50020 50212 50444 50682 51562 51106 51334 50649 50823 51139 51199 50279 50246 50410 50622 51360 51152 51358 50428 51155 50625 50417 50241 50248 51163 50283 51351 50045 51197 50277 50613 50421 50419 51369 50426 50614 50270 50042 51190 50284 51356 51164 51472 50792 51224 51016 50302 50130 51486 50554 51029 51481 50553 50305 51011 51223 50795 50333 50757 50565 51027 51215 51488 51018 51212 51020 50562 50750 50334 50106 51279 50199 50509 51074 50704 51080 50152 50360 50956 50358 50367 51087 50969 50531 50703 51241 51073 50960 50994 51248 50507 50735 50351 50163 51277 50197 51045 50993 50369 51089 50967 50190 51042 50164 50958 50356 50732 50500 50751 50563 50107 50335 51021 51213 51214 51026 50332 50564 50756 51019 51489 51222 51010 51474 50794 51480 50552 50304 50136 50109 50131 50303 51487 50555 50793 51473 51017 51225 51028 50966 51088 50368 50992 50357 50959 50501 50733 51271 50191 51249 50995 50961 50196 51044 51276 50162 50350 51282 50359 50957 51072 51240 50968 51086 50366 50702 50530 50198 51278 50705 50361 51081 50153 51075 ================================================ FILE: imports/ABIDEDataset.py ================================================ import torch from torch_geometric.data import InMemoryDataset,Data from os.path import join, isfile from os import listdir import numpy as np import os.path as osp from imports.read_abide_stats_parall import read_data class ABIDEDataset(InMemoryDataset): def 
__init__(self, root, name, transform=None, pre_transform=None): self.root = root self.name = name super(ABIDEDataset, self).__init__(root,transform, pre_transform) self.data, self.slices = torch.load(self.processed_paths[0]) @property def raw_file_names(self): data_dir = osp.join(self.root,'raw') onlyfiles = [f for f in listdir(data_dir) if osp.isfile(osp.join(data_dir, f))] onlyfiles.sort() return onlyfiles @property def processed_file_names(self): return 'data.pt' def download(self): # Download to `self.raw_dir`. return def process(self): # Read data into huge `Data` list. self.data, self.slices = read_data(self.raw_dir) if self.pre_filter is not None: data_list = [self.get(idx) for idx in range(len(self))] data_list = [data for data in data_list if self.pre_filter(data)] self.data, self.slices = self.collate(data_list) if self.pre_transform is not None: data_list = [self.get(idx) for idx in range(len(self))] data_list = [self.pre_transform(data) for data in data_list] self.data, self.slices = self.collate(data_list) torch.save((self.data, self.slices), self.processed_paths[0]) def __repr__(self): return '{}({})'.format(self.name, len(self)) ================================================ FILE: imports/__inits__.py ================================================ ================================================ FILE: imports/gdc.py ================================================ import torch import numba import numpy as np from scipy.linalg import expm from torch_geometric.utils import add_self_loops, is_undirected, to_dense_adj from torch_sparse import coalesce from torch_scatter import scatter_add def jit(): def decorator(func): try: return numba.jit(cache=True)(func) except RuntimeError: return numba.jit(cache=False)(func) return decorator class GDC(object): r"""Processes the graph via Graph Diffusion Convolution (GDC) from the `"Diffusion Improves Graph Learning" `_ paper. .. note:: The paper offers additional advice on how to choose the hyperparameters. For an example of using GCN with GDC, see `examples/gcn.py `_. Args: self_loop_weight (float, optional): Weight of the added self-loop. Set to :obj:`None` to add no self-loops. (default: :obj:`1`) normalization_in (str, optional): Normalization of the transition matrix on the original (input) graph. Possible values: :obj:`"sym"`, :obj:`"col"`, and :obj:`"row"`. See :func:`GDC.transition_matrix` for details. (default: :obj:`"sym"`) normalization_out (str, optional): Normalization of the transition matrix on the transformed GDC (output) graph. Possible values: :obj:`"sym"`, :obj:`"col"`, :obj:`"row"`, and :obj:`None`. See :func:`GDC.transition_matrix` for details. (default: :obj:`"col"`) diffusion_kwargs (dict, optional): Dictionary containing the parameters for diffusion. `method` specifies the diffusion method (:obj:`"ppr"`, :obj:`"heat"` or :obj:`"coeff"`). Each diffusion method requires different additional parameters. See :func:`GDC.diffusion_matrix_exact` or :func:`GDC.diffusion_matrix_approx` for details. (default: :obj:`dict(method='ppr', alpha=0.15)`) sparsification_kwargs (dict, optional): Dictionary containing the parameters for sparsification. `method` specifies the sparsification method (:obj:`"threshold"` or :obj:`"topk"`). Each sparsification method requires different additional parameters. See :func:`GDC.sparsify_dense` for details. (default: :obj:`dict(method='threshold', avg_degree=64)`) exact (bool, optional): Whether to exactly calculate the diffusion matrix. Note that the exact variants are not scalable. 
They densify the adjacency matrix and calculate either its inverse or its matrix exponential. However, the approximate variants do not support edge weights and currently only personalized PageRank and sparsification by threshold are implemented as fast, approximate versions. (default: :obj:`True`) :rtype: :class:`torch_geometric.data.Data` """ def __init__(self, self_loop_weight=1, normalization_in='sym', normalization_out='col', diffusion_kwargs=dict(method='ppr', alpha=0.15), sparsification_kwargs=dict(method='threshold', avg_degree=64), exact=True): self.self_loop_weight = self_loop_weight self.normalization_in = normalization_in self.normalization_out = normalization_out self.diffusion_kwargs = diffusion_kwargs self.sparsification_kwargs = sparsification_kwargs self.exact = exact if self_loop_weight: assert exact or self_loop_weight == 1 @torch.no_grad() def __call__(self, data): N = data.num_nodes edge_index = data.edge_index if data.edge_attr is None: edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) else: edge_weight = data.edge_attr assert self.exact assert edge_weight.dim() == 1 if self.self_loop_weight: edge_index, edge_weight = add_self_loops( edge_index, edge_weight, fill_value=self.self_loop_weight, num_nodes=N) edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) if self.exact: edge_index, edge_weight = self.transition_matrix( edge_index, edge_weight, N, self.normalization_in) diff_mat = self.diffusion_matrix_exact(edge_index, edge_weight, N, **self.diffusion_kwargs) edge_index, edge_weight = self.sparsify_dense( diff_mat, **self.sparsification_kwargs) else: edge_index, edge_weight = self.diffusion_matrix_approx( edge_index, edge_weight, N, self.normalization_in, **self.diffusion_kwargs) edge_index, edge_weight = self.sparsify_sparse( edge_index, edge_weight, N, **self.sparsification_kwargs) edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N) edge_index, edge_weight = self.transition_matrix( edge_index, edge_weight, N, self.normalization_out) data.edge_index = edge_index data.edge_attr = edge_weight return data def transition_matrix(self, edge_index, edge_weight, num_nodes, normalization): r"""Calculate the approximate, sparse diffusion on a given sparse matrix. Args: edge_index (LongTensor): The edge indices. edge_weight (Tensor): One-dimensional edge weights. num_nodes (int): Number of nodes. normalization (str): Normalization scheme: 1. :obj:`"sym"`: Symmetric normalization :math:`\mathbf{T} = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}`. 2. :obj:`"col"`: Column-wise normalization :math:`\mathbf{T} = \mathbf{A} \mathbf{D}^{-1}`. 3. :obj:`"row"`: Row-wise normalization :math:`\mathbf{T} = \mathbf{D}^{-1} \mathbf{A}`. 4. :obj:`None`: No normalization. :rtype: (:class:`LongTensor`, :class:`Tensor`) """ if normalization == 'sym': row, col = edge_index deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] elif normalization == 'col': _, col = edge_index deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) deg_inv = 1. / deg deg_inv[deg_inv == float('inf')] = 0 edge_weight = edge_weight * deg_inv[col] elif normalization == 'row': row, _ = edge_index deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) deg_inv = 1. 
/ deg deg_inv[deg_inv == float('inf')] = 0 edge_weight = edge_weight * deg_inv[row] elif normalization is None: pass else: raise ValueError( 'Transition matrix normalization {} unknown.'.format( normalization)) return edge_index, edge_weight def diffusion_matrix_exact(self, edge_index, edge_weight, num_nodes, method, **kwargs): r"""Calculate the (dense) diffusion on a given sparse graph. Note that these exact variants are not scalable. They densify the adjacency matrix and calculate either its inverse or its matrix exponential. Args: edge_index (LongTensor): The edge indices. edge_weight (Tensor): One-dimensional edge weights. num_nodes (int): Number of nodes. method (str): Diffusion method: 1. :obj:`"ppr"`: Use personalized PageRank as diffusion. Additionally expects the parameter: - **alpha** (*float*) - Return probability in PPR. Commonly lies in :obj:`[0.05, 0.2]`. 2. :obj:`"heat"`: Use heat kernel diffusion. Additionally expects the parameter: - **t** (*float*) - Time of diffusion. Commonly lies in :obj:`[2, 10]`. 3. :obj:`"coeff"`: Freely choose diffusion coefficients. Additionally expects the parameter: - **coeffs** (*List[float]*) - List of coefficients :obj:`theta_k` for each power of the transition matrix (starting at :obj:`0`). :rtype: (:class:`Tensor`) """ if method == 'ppr': # α (I_n + (α - 1) A)^-1 edge_weight = (kwargs['alpha'] - 1) * edge_weight edge_index, edge_weight = add_self_loops(edge_index, edge_weight, fill_value=1, num_nodes=num_nodes) mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze() diff_matrix = kwargs['alpha'] * torch.inverse(mat) elif method == 'heat': # exp(t (A - I_n)) edge_index, edge_weight = add_self_loops(edge_index, edge_weight, fill_value=-1, num_nodes=num_nodes) edge_weight = kwargs['t'] * edge_weight mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze() undirected = is_undirected(edge_index, edge_weight, num_nodes) diff_matrix = self.__expm__(mat, undirected) elif method == 'coeff': adj_matrix = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze() mat = torch.eye(num_nodes, device=edge_index.device) diff_matrix = kwargs['coeffs'][0] * mat for coeff in kwargs['coeffs'][1:]: mat = mat @ adj_matrix diff_matrix += coeff * mat else: raise ValueError('Exact GDC diffusion {} unknown.'.format(method)) return diff_matrix def diffusion_matrix_approx(self, edge_index, edge_weight, num_nodes, normalization, method, **kwargs): r"""Calculate the approximate, sparse diffusion on a given sparse graph. Args: edge_index (LongTensor): The edge indices. edge_weight (Tensor): One-dimensional edge weights. num_nodes (int): Number of nodes. normalization (str): Transition matrix normalization scheme (:obj:`"sym"`, :obj:`"row"`, or :obj:`"col"`). See :func:`GDC.transition_matrix` for details. method (str): Diffusion method: 1. :obj:`"ppr"`: Use personalized PageRank as diffusion. Additionally expects the parameters: - **alpha** (*float*) - Return probability in PPR. Commonly lies in :obj:`[0.05, 0.2]`. - **eps** (*float*) - Threshold for PPR calculation stopping criterion (:obj:`edge_weight >= eps * out_degree`). Recommended default: :obj:`1e-4`. :rtype: (:class:`LongTensor`, :class:`Tensor`) """ if method == 'ppr': if normalization == 'sym': # Calculate original degrees. _, col = edge_index deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) edge_index_np = edge_index.cpu().numpy() # Assumes coalesced edge_index. 
_, indptr, out_degree = np.unique(edge_index_np[0], return_index=True, return_counts=True) neighbors, neighbor_weights = GDC.__calc_ppr__( indptr, edge_index_np[1], out_degree, kwargs['alpha'], kwargs['eps']) ppr_normalization = 'col' if normalization == 'col' else 'row' edge_index, edge_weight = self.__neighbors_to_graph__( neighbors, neighbor_weights, ppr_normalization, device=edge_index.device) edge_index = edge_index.to(torch.long) if normalization == 'sym': # We can change the normalization from row-normalized to # symmetric by multiplying the resulting matrix with D^{1/2} # from the left and D^{-1/2} from the right. # Since we use the original degrees for this it will be like # we had used symmetric normalization from the beginning # (except for errors due to approximation). row, col = edge_index deg_inv = deg.sqrt() deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 edge_weight = deg_inv[row] * edge_weight * deg_inv_sqrt[col] elif normalization in ['col', 'row']: pass else: raise ValueError( ('Transition matrix normalization {} not implemented for ' 'non-exact GDC computation.').format(normalization)) elif method == 'heat': raise NotImplementedError( ('Currently no fast heat kernel is implemented. You are ' 'welcome to create one yourself, e.g., based on ' '"Kloster and Gleich: Heat kernel based community detection ' '(KDD 2014)."')) else: raise ValueError( 'Approximate GDC diffusion {} unknown.'.format(method)) return edge_index, edge_weight def sparsify_dense(self, matrix, method, **kwargs): r"""Sparsifies the given dense matrix. Args: matrix (Tensor): Matrix to sparsify. num_nodes (int): Number of nodes. method (str): Method of sparsification. Options: 1. :obj:`"threshold"`: Remove all edges with weights smaller than :obj:`eps`. Additionally expects one of these parameters: - **eps** (*float*) - Threshold to bound edges at. - **avg_degree** (*int*) - If :obj:`eps` is not given, it can optionally be calculated by calculating the :obj:`eps` required to achieve a given :obj:`avg_degree`. 2. :obj:`"topk"`: Keep edges with top :obj:`k` edge weights per node (column). Additionally expects the following parameters: - **k** (*int*) - Specifies the number of edges to keep. - **dim** (*int*) - The axis along which to take the top :obj:`k`. 
:rtype: (:class:`LongTensor`, :class:`Tensor`) """ assert matrix.shape[0] == matrix.shape[1] N = matrix.shape[1] if method == 'threshold': if 'eps' not in kwargs.keys(): kwargs['eps'] = self.__calculate_eps__(matrix, N, kwargs['avg_degree']) edge_index = torch.nonzero(matrix >= kwargs['eps']).t() edge_index_flat = edge_index[0] * N + edge_index[1] edge_weight = matrix.flatten()[edge_index_flat] elif method == 'topk': assert kwargs['dim'] in [0, 1] sort_idx = torch.argsort(matrix, dim=kwargs['dim'], descending=True) if kwargs['dim'] == 0: top_idx = sort_idx[:kwargs['k']] edge_weight = torch.gather(matrix, dim=kwargs['dim'], index=top_idx).flatten() row_idx = torch.arange(0, N, device=matrix.device).repeat( kwargs['k']) edge_index = torch.stack([top_idx.flatten(), row_idx], dim=0) else: top_idx = sort_idx[:, :kwargs['k']] edge_weight = torch.gather(matrix, dim=kwargs['dim'], index=top_idx).flatten() col_idx = torch.arange( 0, N, device=matrix.device).repeat_interleave(kwargs['k']) edge_index = torch.stack([col_idx, top_idx.flatten()], dim=0) else: raise ValueError('GDC sparsification {} unknown.'.format(method)) return edge_index, edge_weight def sparsify_sparse(self, edge_index, edge_weight, num_nodes, method, **kwargs): r"""Sparsifies a given sparse graph further. Args: edge_index (LongTensor): The edge indices. edge_weight (Tensor): One-dimensional edge weights. num_nodes (int): Number of nodes. method (str): Method of sparsification: 1. :obj:`"threshold"`: Remove all edges with weights smaller than :obj:`eps`. Additionally expects one of these parameters: - **eps** (*float*) - Threshold to bound edges at. - **avg_degree** (*int*) - If :obj:`eps` is not given, it can optionally be calculated by calculating the :obj:`eps` required to achieve a given :obj:`avg_degree`. :rtype: (:class:`LongTensor`, :class:`Tensor`) """ if method == 'threshold': if 'eps' not in kwargs.keys(): kwargs['eps'] = self.__calculate_eps__(edge_weight, num_nodes, kwargs['avg_degree']) remaining_edge_idx = torch.nonzero( edge_weight >= kwargs['eps']).flatten() edge_index = edge_index[:, remaining_edge_idx] edge_weight = edge_weight[remaining_edge_idx] elif method == 'topk': raise NotImplementedError( 'Sparse topk sparsification not implemented.') else: raise ValueError('GDC sparsification {} unknown.'.format(method)) return edge_index, edge_weight def __expm__(self, matrix, symmetric): r"""Calculates matrix exponential. Args: matrix (Tensor): Matrix to take exponential of. symmetric (bool): Specifies whether the matrix is symmetric. :rtype: (:class:`Tensor`) """ if symmetric: e, V = torch.symeig(matrix, eigenvectors=True) diff_mat = V @ torch.diag(e.exp()) @ V.t() else: diff_mat_np = expm(matrix.cpu().numpy()) diff_mat = torch.Tensor(diff_mat_np).to(matrix.device) return diff_mat def __calculate_eps__(self, matrix, num_nodes, avg_degree): r"""Calculates threshold necessary to achieve a given average degree. Args: matrix (Tensor): Adjacency matrix or edge weights. num_nodes (int): Number of nodes. avg_degree (int): Target average degree. :rtype: (:class:`float`) """ sorted_edges = torch.sort(matrix.flatten(), descending=True).values if avg_degree * num_nodes > len(sorted_edges): return -np.inf return sorted_edges[avg_degree * num_nodes - 1] def __neighbors_to_graph__(self, neighbors, neighbor_weights, normalization='row', device='cpu'): r"""Combine a list of neighbors and neighbor weights to create a sparse graph. Args: neighbors (List[List[int]]): List of neighbors for each node. 
            neighbor_weights (List[List[float]]): List of weights for the
                neighbors of each node.
            normalization (str): Normalization of resulting matrix
                (options: :obj:`"row"`, :obj:`"col"`). (default: :obj:`"row"`)
            device (torch.device): Device to create output tensors on.
                (default: :obj:`"cpu"`)

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        edge_weight = torch.Tensor(np.concatenate(neighbor_weights)).to(device)
        i = np.repeat(np.arange(len(neighbors)),
                      np.fromiter(map(len, neighbors), dtype=np.int))
        j = np.concatenate(neighbors)
        if normalization == 'col':
            edge_index = torch.Tensor(np.vstack([j, i])).to(device)
            N = len(neighbors)
            edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)
        elif normalization == 'row':
            edge_index = torch.Tensor(np.vstack([i, j])).to(device)
        else:
            raise ValueError(
                f"PPR matrix normalization {normalization} unknown.")
        return edge_index, edge_weight

    @staticmethod
    @jit()
    def __calc_ppr__(indptr, indices, out_degree, alpha, eps):
        r"""Calculate the personalized PageRank vector for all nodes
        using a variant of the Andersen algorithm
        (see Andersen et al.: Local Graph Partitioning using PageRank Vectors.)

        Args:
            indptr (np.ndarray): Index pointer for the sparse matrix
                (CSR-format).
            indices (np.ndarray): Indices of the sparse matrix entries
                (CSR-format).
            out_degree (np.ndarray): Out-degree of each node.
            alpha (float): Alpha of the PageRank to calculate.
            eps (float): Threshold for PPR calculation stopping criterion
                (:obj:`edge_weight >= eps * out_degree`).

        :rtype: (:class:`List[List[int]]`, :class:`List[List[float]]`)
        """
        alpha_eps = alpha * eps
        js = []
        vals = []
        for inode in range(len(out_degree)):
            p = {inode: 0.0}
            r = {}
            r[inode] = alpha
            q = [inode]
            while len(q) > 0:
                unode = q.pop()
                res = r[unode] if unode in r else 0
                if unode in p:
                    p[unode] += res
                else:
                    p[unode] = res
                r[unode] = 0
                for vnode in indices[indptr[unode]:indptr[unode + 1]]:
                    _val = (1 - alpha) * res / out_degree[unode]
                    if vnode in r:
                        r[vnode] += _val
                    else:
                        r[vnode] = _val
                    res_vnode = r[vnode] if vnode in r else 0
                    if res_vnode >= alpha_eps * out_degree[vnode]:
                        if vnode not in q:
                            q.append(vnode)
            js.append(list(p.keys()))
            vals.append(list(p.values()))
        return js, vals

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)


================================================
FILE: imports/preprocess_data.py
================================================
# Copyright (c) 2019 Mwiza Kunda
# Copyright (C) 2017 Sarah Parisot, Sofia Ira Ktena
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
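# Typical round-trip with the helpers defined below (IDs and paths are placeholders,
# and assume the derivatives and phenotype CSV were fetched by 01-fetch_data.py):
#   subject_IDs = get_ids()                                     # from subject_IDs.txt
#   labels = get_subject_score(subject_IDs, score='DX_GROUP')   # '1' = autism, '2' = control
#   ts = get_timeseries(subject_IDs, 'cc200')                   # list of (timepoints x 200) arrays
#   conn = subject_connectivity(ts, subject_IDs, 'cc200', 'correlation')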
import os import warnings import glob import csv import re import numpy as np import scipy.io as sio import sys from nilearn import connectome import pandas as pd from scipy.spatial import distance from scipy import signal from sklearn.compose import ColumnTransformer from sklearn.preprocessing import Normalizer from sklearn.preprocessing import OrdinalEncoder from sklearn.preprocessing import OneHotEncoder from sklearn.preprocessing import StandardScaler warnings.filterwarnings("ignore") # Input data variables root_folder = '/home/azureuser/projects/BrainGNN/data/' data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal') phenotype = os.path.join(root_folder, 'ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv') def fetch_filenames(subject_IDs, file_type, atlas): """ subject_list : list of short subject IDs in string format file_type : must be one of the available file types filemapping : resulting file name format returns: filenames : list of filetypes (same length as subject_list) """ filemapping = {'func_preproc': '_func_preproc.nii.gz', 'rois_' + atlas: '_rois_' + atlas + '.1D'} # The list to be filled filenames = [] # Fill list with requested file paths for i in range(len(subject_IDs)): os.chdir(data_folder) try: try: os.chdir(data_folder) filenames.append(glob.glob('*' + subject_IDs[i] + filemapping[file_type])[0]) except: os.chdir(data_folder + '/' + subject_IDs[i]) filenames.append(glob.glob('*' + subject_IDs[i] + filemapping[file_type])[0]) except IndexError: filenames.append('N/A') return filenames # Get timeseries arrays for list of subjects def get_timeseries(subject_list, atlas_name, silence=False): """ subject_list : list of short subject IDs in string format atlas_name : the atlas based on which the timeseries are generated e.g. aal, cc200 returns: time_series : list of timeseries arrays, each of shape (timepoints x regions) """ timeseries = [] for i in range(len(subject_list)): subject_folder = os.path.join(data_folder, subject_list[i]) ro_file = [f for f in os.listdir(subject_folder) if f.endswith('_rois_' + atlas_name + '.1D')] fl = os.path.join(subject_folder, ro_file[0]) if silence != True: print("Reading timeseries file %s" % fl) timeseries.append(np.loadtxt(fl, skiprows=0)) return timeseries # compute connectivity matrices def subject_connectivity(timeseries, subjects, atlas_name, kind, iter_no='', seed=1234, n_subjects='', save=True, save_path=data_folder): """ timeseries : timeseries table for subject (timepoints x regions) subjects : subject IDs atlas_name : name of the parcellation atlas used kind : the kind of connectivity to be used, e.g. 
lasso, partial correlation, correlation iter_no : tangent connectivity iteration number for cross validation evaluation save : save the connectivity matrix to a file save_path : specify path to save the matrix if different from subject folder returns: connectivity : connectivity matrix (regions x regions) """ if kind in ['TPE', 'TE', 'correlation','partial correlation']: if kind not in ['TPE', 'TE']: conn_measure = connectome.ConnectivityMeasure(kind=kind) connectivity = conn_measure.fit_transform(timeseries) else: if kind == 'TPE': conn_measure = connectome.ConnectivityMeasure(kind='correlation') conn_mat = conn_measure.fit_transform(timeseries) conn_measure = connectome.ConnectivityMeasure(kind='tangent') connectivity_fit = conn_measure.fit(conn_mat) connectivity = connectivity_fit.transform(conn_mat) else: conn_measure = connectome.ConnectivityMeasure(kind='tangent') connectivity_fit = conn_measure.fit(timeseries) connectivity = connectivity_fit.transform(timeseries) if save: if kind not in ['TPE', 'TE']: for i, subj_id in enumerate(subjects): subject_file = os.path.join(save_path, subj_id, subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '.mat') sio.savemat(subject_file, {'connectivity': connectivity[i]}) return connectivity else: for i, subj_id in enumerate(subjects): subject_file = os.path.join(save_path, subj_id, subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '_' + str( iter_no) + '_' + str(seed) + '_' + validation_ext + str( n_subjects) + '.mat') sio.savemat(subject_file, {'connectivity': connectivity[i]}) return connectivity_fit # Get the list of subject IDs def get_ids(num_subjects=None): """ return: subject_IDs : list of all subject IDs """ subject_IDs = np.genfromtxt(os.path.join(data_folder, 'subject_IDs.txt'), dtype=str) if num_subjects is not None: subject_IDs = subject_IDs[:num_subjects] return subject_IDs # Get phenotype values for a list of subjects def get_subject_score(subject_list, score): scores_dict = {} with open(phenotype) as csv_file: reader = csv.DictReader(csv_file) for row in reader: if row['SUB_ID'] in subject_list: if score == 'HANDEDNESS_CATEGORY': if (row[score].strip() == '-9999') or (row[score].strip() == ''): scores_dict[row['SUB_ID']] = 'R' elif row[score] == 'Mixed': scores_dict[row['SUB_ID']] = 'Ambi' elif row[score] == 'L->R': scores_dict[row['SUB_ID']] = 'Ambi' else: scores_dict[row['SUB_ID']] = row[score] elif (score == 'FIQ' or score == 'PIQ' or score == 'VIQ'): if (row[score].strip() == '-9999') or (row[score].strip() == ''): scores_dict[row['SUB_ID']] = 100 else: scores_dict[row['SUB_ID']] = float(row[score]) else: scores_dict[row['SUB_ID']] = row[score] return scores_dict # preprocess phenotypes. 
Categorical -> ordinal representation def preprocess_phenotypes(pheno_ft, params): if params['model'] == 'MIDA': ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2])], remainder='passthrough') else: ct = ColumnTransformer([("ordinal", OrdinalEncoder(), [0, 1, 2, 3])], remainder='passthrough') pheno_ft = ct.fit_transform(pheno_ft) pheno_ft = pheno_ft.astype('float32') return (pheno_ft) # create phenotype feature vector to concatenate with fmri feature vectors def phenotype_ft_vector(pheno_ft, num_subjects, params): gender = pheno_ft[:, 0] if params['model'] == 'MIDA': eye = pheno_ft[:, 0] hand = pheno_ft[:, 2] age = pheno_ft[:, 3] fiq = pheno_ft[:, 4] else: eye = pheno_ft[:, 2] hand = pheno_ft[:, 3] age = pheno_ft[:, 4] fiq = pheno_ft[:, 5] phenotype_ft = np.zeros((num_subjects, 4)) phenotype_ft_eye = np.zeros((num_subjects, 2)) phenotype_ft_hand = np.zeros((num_subjects, 3)) for i in range(num_subjects): phenotype_ft[i, int(gender[i])] = 1 phenotype_ft[i, -2] = age[i] phenotype_ft[i, -1] = fiq[i] phenotype_ft_eye[i, int(eye[i])] = 1 phenotype_ft_hand[i, int(hand[i])] = 1 if params['model'] == 'MIDA': phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand], axis=1) else: phenotype_ft = np.concatenate([phenotype_ft, phenotype_ft_hand, phenotype_ft_eye], axis=1) return phenotype_ft # Load precomputed fMRI connectivity networks def get_networks(subject_list, kind, iter_no='', seed=1234, n_subjects='', atlas_name="aal", variable='connectivity'): """ subject_list : list of subject IDs kind : the kind of connectivity to be used, e.g. lasso, partial correlation, correlation atlas_name : name of the parcellation atlas used variable : variable name in the .mat file that has been used to save the precomputed networks return: matrix : feature matrix of connectivity networks (num_subjects x network_size) """ all_networks = [] for subject in subject_list: if len(kind.split()) == 2: kind = '_'.join(kind.split()) fl = os.path.join(data_folder, subject, subject + "_" + atlas_name + "_" + kind.replace(' ', '_') + ".mat") matrix = sio.loadmat(fl)[variable] all_networks.append(matrix) if kind in ['TE', 'TPE']: norm_networks = [mat for mat in all_networks] else: norm_networks = [np.arctanh(mat) for mat in all_networks] networks = np.stack(norm_networks) return networks ================================================ FILE: imports/read_abide_stats_parall.py ================================================ ''' Author: Xiaoxiao Li Date: 2019/02/24 ''' import os.path as osp from os import listdir import os import glob import h5py import torch import numpy as np from scipy.io import loadmat from torch_geometric.data import Data import networkx as nx from networkx.convert_matrix import from_numpy_matrix import multiprocessing from torch_sparse import coalesce from torch_geometric.utils import remove_self_loops from functools import partial import deepdish as dd from imports.gdc import GDC def split(data, batch): node_slice = torch.cumsum(torch.from_numpy(np.bincount(batch)), 0) node_slice = torch.cat([torch.tensor([0]), node_slice]) row, _ = data.edge_index edge_slice = torch.cumsum(torch.from_numpy(np.bincount(batch[row])), 0) edge_slice = torch.cat([torch.tensor([0]), edge_slice]) # Edge indices should start at zero for every graph. 
data.edge_index -= node_slice[batch[row]].unsqueeze(0) slices = {'edge_index': edge_slice} if data.x is not None: slices['x'] = node_slice if data.edge_attr is not None: slices['edge_attr'] = edge_slice if data.y is not None: if data.y.size(0) == batch.size(0): slices['y'] = node_slice else: slices['y'] = torch.arange(0, batch[-1] + 2, dtype=torch.long) if data.pos is not None: slices['pos'] = node_slice return data, slices def cat(seq): seq = [item for item in seq if item is not None] seq = [item.unsqueeze(-1) if item.dim() == 1 else item for item in seq] return torch.cat(seq, dim=-1).squeeze() if len(seq) > 0 else None class NoDaemonProcess(multiprocessing.Process): @property def daemon(self): return False @daemon.setter def daemon(self, value): pass class NoDaemonContext(type(multiprocessing.get_context())): Process = NoDaemonProcess def read_data(data_dir): onlyfiles = [f for f in listdir(data_dir) if osp.isfile(osp.join(data_dir, f))] onlyfiles.sort() batch = [] pseudo = [] y_list = [] edge_att_list, edge_index_list,att_list = [], [], [] # parallar computing cores = multiprocessing.cpu_count() pool = multiprocessing.Pool(processes=cores) #pool = MyPool(processes = cores) func = partial(read_sigle_data, data_dir) import timeit start = timeit.default_timer() res = pool.map(func, onlyfiles) pool.close() pool.join() stop = timeit.default_timer() print('Time: ', stop - start) for j in range(len(res)): edge_att_list.append(res[j][0]) edge_index_list.append(res[j][1]+j*res[j][4]) att_list.append(res[j][2]) y_list.append(res[j][3]) batch.append([j]*res[j][4]) pseudo.append(np.diag(np.ones(res[j][4]))) edge_att_arr = np.concatenate(edge_att_list) edge_index_arr = np.concatenate(edge_index_list, axis=1) att_arr = np.concatenate(att_list, axis=0) pseudo_arr = np.concatenate(pseudo, axis=0) y_arr = np.stack(y_list) edge_att_torch = torch.from_numpy(edge_att_arr.reshape(len(edge_att_arr), 1)).float() att_torch = torch.from_numpy(att_arr).float() y_torch = torch.from_numpy(y_arr).long() # classification batch_torch = torch.from_numpy(np.hstack(batch)).long() edge_index_torch = torch.from_numpy(edge_index_arr).long() pseudo_torch = torch.from_numpy(pseudo_arr).float() data = Data(x=att_torch, edge_index=edge_index_torch, y=y_torch, edge_attr=edge_att_torch, pos = pseudo_torch ) data, slices = split(data, batch_torch) return data, slices def read_sigle_data(data_dir,filename,use_gdc =False): temp = dd.io.load(osp.join(data_dir, filename)) # read edge and edge attribute pcorr = np.abs(temp['pcorr'][()]) num_nodes = pcorr.shape[0] G = from_numpy_matrix(pcorr) A = nx.to_scipy_sparse_matrix(G) adj = A.tocoo() edge_att = np.zeros(len(adj.row)) for i in range(len(adj.row)): edge_att[i] = pcorr[adj.row[i], adj.col[i]] edge_index = np.stack([adj.row, adj.col]) edge_index, edge_att = remove_self_loops(torch.from_numpy(edge_index), torch.from_numpy(edge_att)) edge_index = edge_index.long() edge_index, edge_att = coalesce(edge_index, edge_att, num_nodes, num_nodes) att = temp['corr'][()] label = temp['label'][()] att_torch = torch.from_numpy(att).float() y_torch = torch.from_numpy(np.array(label)).long() # classification data = Data(x=att_torch, edge_index=edge_index.long(), y=y_torch, edge_attr=edge_att) if use_gdc: ''' Implementation of https://papers.nips.cc/paper/2019/hash/23c894276a2c5a16470e6a31f4618d73-Abstract.html ''' data.edge_attr = data.edge_attr.squeeze() gdc = GDC(self_loop_weight=1, normalization_in='sym', normalization_out='col', diffusion_kwargs=dict(method='ppr', alpha=0.2), 

def read_sigle_data(data_dir, filename, use_gdc=False):

    temp = dd.io.load(osp.join(data_dir, filename))

    # read edge and edge attribute
    pcorr = np.abs(temp['pcorr'][()])

    num_nodes = pcorr.shape[0]
    G = from_numpy_matrix(pcorr)
    A = nx.to_scipy_sparse_matrix(G)
    adj = A.tocoo()
    edge_att = np.zeros(len(adj.row))
    for i in range(len(adj.row)):
        edge_att[i] = pcorr[adj.row[i], adj.col[i]]

    edge_index = np.stack([adj.row, adj.col])
    edge_index, edge_att = remove_self_loops(torch.from_numpy(edge_index), torch.from_numpy(edge_att))
    edge_index = edge_index.long()
    edge_index, edge_att = coalesce(edge_index, edge_att, num_nodes, num_nodes)
    att = temp['corr'][()]
    label = temp['label'][()]

    att_torch = torch.from_numpy(att).float()
    y_torch = torch.from_numpy(np.array(label)).long()  # classification

    data = Data(x=att_torch, edge_index=edge_index.long(), y=y_torch, edge_attr=edge_att)

    if use_gdc:
        '''
        Implementation of https://papers.nips.cc/paper/2019/hash/23c894276a2c5a16470e6a31f4618d73-Abstract.html
        '''
        data.edge_attr = data.edge_attr.squeeze()
        gdc = GDC(self_loop_weight=1, normalization_in='sym',
                  normalization_out='col',
                  diffusion_kwargs=dict(method='ppr', alpha=0.2),
                  sparsification_kwargs=dict(method='topk', k=20, dim=0),
                  exact=True)
        data = gdc(data)
        return data.edge_attr.data.numpy(), data.edge_index.data.numpy(), data.x.data.numpy(), data.y.data.item(), num_nodes
    else:
        return edge_att.data.numpy(), edge_index.data.numpy(), att, label, num_nodes


if __name__ == "__main__":
    data_dir = '/home/azureuser/projects/BrainGNN/data/ABIDE_pcp/cpac/filt_noglobal/raw'
    filename = '50346.h5'
    read_sigle_data(data_dir, filename)


================================================
FILE: imports/utils.py
================================================
from scipy import stats
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.io import loadmat
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import KFold


def train_val_test_split(kfold=5, fold=0):
    n_sub = 1035
    id = list(range(n_sub))

    import random
    random.seed(123)
    random.shuffle(id)

    kf = KFold(n_splits=kfold, random_state=123, shuffle=True)
    kf2 = KFold(n_splits=kfold - 1, shuffle=True, random_state=666)

    test_index = list()
    train_index = list()
    val_index = list()

    for tr, te in kf.split(np.array(id)):
        test_index.append(te)
        tr_id, val_id = list(kf2.split(tr))[0]
        train_index.append(tr[tr_id])
        val_index.append(tr[val_id])

    train_id = train_index[fold]
    test_id = test_index[fold]
    val_id = val_index[fold]

    return train_id, val_id, test_id
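
A brief sketch of how these fold indices are meant to be consumed; `dataset` is a hypothetical stand-in for the ABIDEDataset wired up in 03-main.py.

# Illustrative use of train_val_test_split(); with kfold=5 over 1035 subjects
# the test and validation folds each hold ~207 subjects, training ~621.
tr_index, val_index, te_index = train_val_test_split(kfold=5, fold=0)
print(len(tr_index), len(val_index), len(te_index))

# train_dataset = dataset[tr_index.tolist()]   # `dataset` is a placeholder
# val_dataset   = dataset[val_index.tolist()]
# test_dataset  = dataset[te_index.tolist()]
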

================================================
FILE: net/braingnn.py
================================================
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.utils import (add_self_loops, sort_edge_index,
                                   remove_self_loops)
from torch_sparse import spspmm

from net.braingraphconv import MyNNConv


##########################################################################################################################
class Network(torch.nn.Module):
    def __init__(self, indim, ratio, nclass, k=8, R=200):
        '''
        :param indim: (int) node feature dimension
        :param ratio: (float) pooling ratio in (0,1)
        :param nclass: (int) number of classes
        :param k: (int) number of communities
        :param R: (int) number of ROIs
        '''
        super(Network, self).__init__()

        self.indim = indim
        self.dim1 = 32
        self.dim2 = 32
        self.dim3 = 512
        self.dim4 = 256
        self.dim5 = 8
        self.k = k
        self.R = R

        self.n1 = nn.Sequential(nn.Linear(self.R, self.k, bias=False), nn.ReLU(),
                                nn.Linear(self.k, self.dim1 * self.indim))
        self.conv1 = MyNNConv(self.indim, self.dim1, self.n1, normalize=False)
        self.pool1 = TopKPooling(self.dim1, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid)
        self.n2 = nn.Sequential(nn.Linear(self.R, self.k, bias=False), nn.ReLU(),
                                nn.Linear(self.k, self.dim2 * self.dim1))
        self.conv2 = MyNNConv(self.dim1, self.dim2, self.n2, normalize=False)
        self.pool2 = TopKPooling(self.dim2, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid)

        # self.fc1 = torch.nn.Linear((self.dim2) * 2, self.dim2)
        self.fc1 = torch.nn.Linear((self.dim1 + self.dim2) * 2, self.dim2)
        self.bn1 = torch.nn.BatchNorm1d(self.dim2)
        self.fc2 = torch.nn.Linear(self.dim2, self.dim3)
        self.bn2 = torch.nn.BatchNorm1d(self.dim3)
        self.fc3 = torch.nn.Linear(self.dim3, nclass)

    def forward(self, x, edge_index, batch, edge_attr, pos):

        x = self.conv1(x, edge_index, edge_attr, pos)
        x, edge_index, edge_attr, batch, perm, score1 = self.pool1(x, edge_index, edge_attr, batch)

        pos = pos[perm]
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        edge_attr = edge_attr.squeeze()
        edge_index, edge_attr = self.augment_adj(edge_index, edge_attr, x.size(0))

        x = self.conv2(x, edge_index, edge_attr, pos)
        x, edge_index, edge_attr, batch, perm, score2 = self.pool2(x, edge_index, edge_attr, batch)

        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = torch.cat([x1, x2], dim=1)
        x = self.bn1(F.relu(self.fc1(x)))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.bn2(F.relu(self.fc2(x)))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.log_softmax(self.fc3(x), dim=-1)

        return x, self.pool1.weight, self.pool2.weight, \
            torch.sigmoid(score1).view(x.size(0), -1), torch.sigmoid(score2).view(x.size(0), -1)

    def augment_adj(self, edge_index, edge_weight, num_nodes):
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight, num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight, num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index, edge_weight,
                                         num_nodes, num_nodes, num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight


================================================
FILE: net/braingraphconv.py
================================================
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from net.brainmsgpassing import MyMessagePassing
from torch_geometric.utils import add_remaining_self_loops, softmax
from torch_geometric.typing import (OptTensor)

from net.inits import uniform


class MyNNConv(MyMessagePassing):
    def __init__(self, in_channels, out_channels, nn, normalize=False, bias=True, **kwargs):
        super(MyNNConv, self).__init__(aggr='mean', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.nn = nn
        # self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        # uniform(self.in_channels, self.weight)
        uniform(self.in_channels, self.bias)

    def forward(self, x, edge_index, edge_weight=None, pseudo=None, size=None):
        """"""
        edge_weight = edge_weight.squeeze()
        if size is None and torch.is_tensor(x):
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, 1, x.size(0))

        weight = self.nn(pseudo).view(-1, self.in_channels, self.out_channels)
        if torch.is_tensor(x):
            x = torch.matmul(x.unsqueeze(1), weight).squeeze(1)
        else:
            x = (None if x[0] is None else torch.matmul(x[0].unsqueeze(1), weight).squeeze(1),
                 None if x[1] is None else torch.matmul(x[1].unsqueeze(1), weight).squeeze(1))

        # weight = self.nn(pseudo).view(-1, self.out_channels, self.in_channels)
        # if torch.is_tensor(x):
        #     x = torch.matmul(x.unsqueeze(1), weight.permute(0, 2, 1)).squeeze(1)
        # else:
        #     x = (None if x[0] is None else torch.matmul(x[0].unsqueeze(1), weight).squeeze(1),
        #          None if x[1] is None else torch.matmul(x[1].unsqueeze(1), weight).squeeze(1))

        return self.propagate(edge_index, size=size, x=x, edge_weight=edge_weight)

    def message(self, edge_index_i, size_i, x_j, edge_weight, ptr: OptTensor):
        edge_weight = softmax(edge_weight, edge_index_i, ptr, size_i)
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        if self.normalize:
            aggr_out = F.normalize(aggr_out, p=2, dim=-1)
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
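
A shape-level smoke test for Network (illustrative only, not part of the repository): R-dimensional node features and identity pseudo-coordinates mirror the layout that read_abide_stats_parall.py constructs; the ring-graph edges are synthetic stand-ins for the real connectivity graphs.

import torch
from net.braingnn import Network

R, nclass = 200, 2
model = Network(indim=R, ratio=0.5, nclass=nclass, k=8, R=R).eval()

# Two toy "subjects": each node carries an R-dim connectivity row; pos is the
# one-hot ROI identity matrix, as in read_data(). Edges form a ring per subject.
x = torch.randn(2 * R, R)
pos = torch.eye(R).repeat(2, 1)
batch = torch.cat([torch.zeros(R), torch.ones(R)]).long()
idx = torch.arange(R)
ring = torch.stack([idx, (idx + 1) % R])
edge_index = torch.cat([ring, ring + R], dim=1)   # offset the second graph
edge_attr = torch.rand(edge_index.size(1), 1)

out, w1, w2, s1, s2 = model(x, edge_index, batch, edge_attr, pos)
print(out.shape)  # (2, nclass) log-probabilities
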

================================================
FILE: net/brainmsgpassing.py
================================================
import sys
import inspect

import torch
# from torch_geometric.utils import scatter_
from torch_scatter import scatter, scatter_add

special_args = [
    'edge_index', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j'
]
__size_error_msg__ = ('All tensors which should get mapped to the same source '
                      'or target nodes must be of same size in dimension 0.')

is_python2 = sys.version_info[0] < 3
getargspec = inspect.getargspec if is_python2 else inspect.getfullargspec


class MyMessagePassing(torch.nn.Module):
    r"""Base class for creating message passing layers

    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
        \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
        \left(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j}\right) \right),

    where :math:`\square` denotes a differentiable, permutation invariant
    function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
    and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
    MLPs.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html>`__
    for the accompanying tutorial.

    Args:
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`).
            (default: :obj:`"add"`)
        flow (string, optional): The flow direction of message passing
            (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
            (default: :obj:`"source_to_target"`)
        node_dim (int, optional): The axis along which to propagate.
            (default: :obj:`0`)
    """

    def __init__(self, aggr='add', flow='source_to_target', node_dim=0):
        super(MyMessagePassing, self).__init__()

        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max']

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim
        assert self.node_dim >= 0

        self.__message_args__ = getargspec(self.message)[0][1:]
        self.__special_args__ = [(i, arg)
                                 for i, arg in enumerate(self.__message_args__)
                                 if arg in special_args]
        self.__message_args__ = [
            arg for arg in self.__message_args__ if arg not in special_args
        ]
        self.__update_args__ = getargspec(self.update)[0][2:]

    def propagate(self, edge_index, size=None, **kwargs):
        r"""The initial call to start propagating messages.

        Args:
            edge_index (Tensor): The indices of a general (sparse) assignment
                matrix with shape :obj:`[N, M]` (can be directed or
                undirected).
            size (list or tuple, optional): The size :obj:`[N, M]` of the
                assignment matrix. If set to :obj:`None`, the size is
                automatically inferred and assumed to be symmetric.
                (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct
                messages and to update node embeddings.
""" dim = self.node_dim size = [None, None] if size is None else list(size) assert len(size) == 2 i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0) ij = {"_i": i, "_j": j} message_args = [] for arg in self.__message_args__: if arg[-2:] in ij.keys(): tmp = kwargs.get(arg[:-2], None) if tmp is None: # pragma: no cover message_args.append(tmp) else: idx = ij[arg[-2:]] if isinstance(tmp, tuple) or isinstance(tmp, list): assert len(tmp) == 2 if tmp[1 - idx] is not None: if size[1 - idx] is None: size[1 - idx] = tmp[1 - idx].size(dim) if size[1 - idx] != tmp[1 - idx].size(dim): raise ValueError(__size_error_msg__) tmp = tmp[idx] if tmp is None: message_args.append(tmp) else: if size[idx] is None: size[idx] = tmp.size(dim) if size[idx] != tmp.size(dim): raise ValueError(__size_error_msg__) tmp = torch.index_select(tmp, dim, edge_index[idx]) message_args.append(tmp) else: message_args.append(kwargs.get(arg, None)) size[0] = size[1] if size[0] is None else size[0] size[1] = size[0] if size[1] is None else size[1] kwargs['edge_index'] = edge_index kwargs['size'] = size for (idx, arg) in self.__special_args__: if arg[-2:] in ij.keys(): message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]]) else: message_args.insert(idx, kwargs[arg]) update_args = [kwargs[arg] for arg in self.__update_args__] out = self.message(*message_args) # out = scatter_(self.aggr, out, edge_index[i], dim, dim_size=size[i]) out = scatter_add(out, edge_index[i], dim, dim_size=size[i]) out = self.update(out, *update_args) return out def message(self, x_j): # pragma: no cover r"""Constructs messages to node :math:`i` in analogy to :math:`\phi_{\mathbf{\Theta}}` for each edge in :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`. Can take any argument which was initially passed to :meth:`propagate`. In addition, tensors passed to :meth:`propagate` can be mapped to the respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`. """ return x_j def update(self, aggr_out): # pragma: no cover r"""Updates node embeddings in analogy to :math:`\gamma_{\mathbf{\Theta}}` for each node :math:`i \in \mathcal{V}`. 

================================================
FILE: net/inits.py
================================================
import math


def uniform(size, tensor):
    bound = 1.0 / math.sqrt(size)
    if tensor is not None:
        tensor.data.uniform_(-bound, bound)


def kaiming_uniform(tensor, fan, a):
    if tensor is not None:
        bound = math.sqrt(6 / ((1 + a**2) * fan))
        tensor.data.uniform_(-bound, bound)


def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)


def ones(tensor):
    if tensor is not None:
        tensor.data.fill_(1)


================================================
FILE: requirements.txt
================================================
alabaster @ file:///home/ktietz/src/ci/alabaster_1611921544520/work
anaconda-client==1.7.2
anaconda-project @ file:///tmp/build/80754af9/anaconda-project_1610472525955/work
anyio @ file:///tmp/build/80754af9/anyio_1617783275907/work/dist
appdirs==1.4.4
argh==0.26.2
argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613037097816/work
arrow==0.13.1
ase==3.21.1
asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work
astroid @ file:///tmp/build/80754af9/astroid_1613500854201/work
astropy @ file:///tmp/build/80754af9/astropy_1617745353437/work
async-generator @ file:///home/ktietz/src/ci/async_generator_1611927993394/work
atomicwrites==1.4.0
attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work
autopep8 @ file:///tmp/build/80754af9/autopep8_1615918855173/work
Babel @ file:///tmp/build/80754af9/babel_1607110387436/work
backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
backports.shutil-get-terminal-size @ file:///tmp/build/80754af9/backports.shutil_get_terminal_size_1608222128777/work
beautifulsoup4 @ file:///home/linux1/recipes/ci/beautifulsoup4_1610988766420/work
binaryornot @ file:///tmp/build/80754af9/binaryornot_1617751525010/work
bitarray @ file:///tmp/build/80754af9/bitarray_1618431750766/work
bkcharts==0.2
black==19.10b0
bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work
bokeh @ file:///tmp/build/80754af9/bokeh_1617824541184/work
boto==2.49.0
Bottleneck==1.3.2
brotlipy==0.7.0
certifi==2020.12.5
cffi @ file:///tmp/build/80754af9/cffi_1613246945912/work
chardet @ file:///tmp/build/80754af9/chardet_1607706746162/work
click @ file:///home/linux1/recipes/ci/click_1610990599742/work
cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work
clyent==1.2.2
colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work
contextlib2==0.6.0.post1
cookiecutter @ file:///tmp/build/80754af9/cookiecutter_1617748928239/work
cryptography @ file:///tmp/build/80754af9/cryptography_1616769286105/work
cycler==0.10.0
Cython @ file:///tmp/build/80754af9/cython_1618435160151/work
cytoolz==0.11.0
dask @ file:///tmp/build/80754af9/dask-core_1617390489108/work
decorator @ file:///tmp/build/80754af9/decorator_1617916966915/work
deepdish==0.3.6
defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work
distributed @ file:///tmp/build/80754af9/distributed_1617381497899/work
docutils @ file:///tmp/build/80754af9/docutils_1617624660125/work
entrypoints==0.3
et-xmlfile==1.0.1
fastcache==1.1.0
filelock @ file:///home/linux1/recipes/ci/filelock_1610993975404/work
flake8 @ file:///tmp/build/80754af9/flake8_1615834841867/work
Flask @ file:///home/ktietz/src/ci/flask_1611932660458/work
fsspec @ file:///tmp/build/80754af9/fsspec_1617959894824/work
future==0.18.2
gevent @ file:///tmp/build/80754af9/gevent_1616770671827/work
glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work
gmpy2==2.0.8
googledrivedownloader==0.4
greenlet @ file:///tmp/build/80754af9/greenlet_1611957705398/work
h5py==2.10.0
HeapDict==1.0.1
html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work
idna @ file:///home/linux1/recipes/ci/idna_1610986105248/work
imageio @ file:///tmp/build/80754af9/imageio_1617700267927/work
imagesize @ file:///home/ktietz/src/ci/imagesize_1611921604382/work
importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617874469820/work
inflection==0.5.1
iniconfig @ file:///home/linux1/recipes/ci/iniconfig_1610983019677/work
intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work
ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl
ipython @ file:///tmp/build/80754af9/ipython_1617120885885/work
ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work
isodate==0.6.0
isort @ file:///tmp/build/80754af9/isort_1616355431277/work
itsdangerous @ file:///home/ktietz/src/ci/itsdangerous_1611932585308/work
jdcal==1.4.1
jedi @ file:///tmp/build/80754af9/jedi_1606932564285/work
jeepney @ file:///tmp/build/80754af9/jeepney_1606148855031/work
Jinja2 @ file:///tmp/build/80754af9/jinja2_1612213139570/work
jinja2-time @ file:///tmp/build/80754af9/jinja2-time_1617751524098/work
joblib @ file:///tmp/build/80754af9/joblib_1613502643832/work
json5==0.9.5
jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
jupyter==1.0.0
jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work
jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work
jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213311222/work
jupyter-packaging @ file:///tmp/build/80754af9/jupyter-packaging_1613502826984/work
jupyter-server @ file:///tmp/build/80754af9/jupyter_server_1616083640759/work
jupyterlab @ file:///tmp/build/80754af9/jupyterlab_1619133235951/work
jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1617134334258/work
jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work
keyring @ file:///tmp/build/80754af9/keyring_1614616740399/work
kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612282420641/work
lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1616526917483/work
libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work
llvmlite==0.36.0
locket==0.2.1
lxml @ file:///tmp/build/80754af9/lxml_1616443220220/work
MarkupSafe==1.1.1
matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1613407855456/work
mccabe==0.6.1
mistune==0.8.4
mkl-fft==1.3.0
mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853849286/work
mkl-service==2.3.0
mock @ file:///tmp/build/80754af9/mock_1607622725907/work
more-itertools @ file:///tmp/build/80754af9/more-itertools_1613676688952/work
mpmath==1.2.1
msgpack @ file:///tmp/build/80754af9/msgpack-python_1612287151062/work
multipledispatch==0.6.0
mypy-extensions==0.4.3
nbclassic @ file:///tmp/build/80754af9/nbclassic_1616085367084/work
nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914830498/work
nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work
nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work
nibabel==3.2.1
nilearn==0.7.1
nltk @ file:///tmp/build/80754af9/nltk_1618327084230/work
nose @ file:///tmp/build/80754af9/nose_1606773131901/work
notebook @ file:///tmp/build/80754af9/notebook_1616443462982/work
numba @ file:///tmp/build/80754af9/numba_1616774046117/work
numexpr @ file:///tmp/build/80754af9/numexpr_1618856167419/work
numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1618497241363/work
numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work
olefile==0.46
openpyxl @ file:///tmp/build/80754af9/openpyxl_1615411699337/work
packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work
pandas==1.2.4
pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work
parso==0.7.0
partd @ file:///tmp/build/80754af9/partd_1618000087440/work
path @ file:///tmp/build/80754af9/path_1614022220526/work
pathlib2 @ file:///tmp/build/80754af9/pathlib2_1607024983162/work
pathspec==0.7.0
pathtools==0.1.2
patsy==0.5.1
pep8==1.7.1
pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
Pillow @ file:///tmp/build/80754af9/pillow_1617383569452/work
pkginfo==1.7.0
pluggy @ file:///tmp/build/80754af9/pluggy_1615976321666/work
ply==3.11
poyo @ file:///tmp/build/80754af9/poyo_1617751526755/work
prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1618088486455/work
prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work
protobuf==3.17.0
psutil @ file:///tmp/build/80754af9/psutil_1612298023621/work
ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
py @ file:///tmp/build/80754af9/py_1607971587848/work
pycodestyle @ file:///home/ktietz/src/ci_mi/pycodestyle_1612807597675/work
pycosat==0.6.3
pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
pycurl==7.43.0.6
pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1616182067796/work
pyerfa @ file:///tmp/build/80754af9/pyerfa_1619390903914/work
pyflakes @ file:///home/ktietz/src/ci_ipy2/pyflakes_1612551159640/work
Pygments @ file:///tmp/build/80754af9/pygments_1615143339740/work
pylint @ file:///tmp/build/80754af9/pylint_1617135829881/work
pyls-black @ file:///tmp/build/80754af9/pyls-black_1607553132291/work
pyls-spyder @ file:///tmp/build/80754af9/pyls-spyder_1613849700860/work
pyodbc===4.0.0-unsupported
pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1608057966937/work
pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work
PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
pytest==6.2.3
python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work
python-language-server @ file:///tmp/build/80754af9/python-language-server_1607972495879/work
python-louvain==0.15
python-slugify @ file:///tmp/build/80754af9/python-slugify_1620405669636/work
pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658317819/work
pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work
PyYAML==5.4.1
pyzmq==20.0.0
QDarkStyle @ file:///tmp/build/80754af9/qdarkstyle_1617386714626/work
qstylizer @ file:///tmp/build/80754af9/qstylizer_1617713584600/work/dist/qstylizer-0.1.10-py2.py3-none-any.whl
QtAwesome @ file:///tmp/build/80754af9/qtawesome_1615991616277/work
qtconsole @ file:///tmp/build/80754af9/qtconsole_1616775094278/work
QtPy==1.9.0
rdflib==5.0.0
regex @ file:///tmp/build/80754af9/regex_1617569202463/work
requests @ file:///tmp/build/80754af9/requests_1608241421344/work
rope @ file:///tmp/build/80754af9/rope_1602264064449/work
Rtree @ file:///tmp/build/80754af9/rtree_1618420845272/work
ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016699510/work
scikit-image==0.16.2
scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1614446682169/work
scipy @ file:///tmp/build/80754af9/scipy_1618855647378/work
seaborn @ file:///tmp/build/80754af9/seaborn_1608578541026/work
SecretStorage @ file:///tmp/build/80754af9/secretstorage_1614022784285/work
Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
simplegeneric==0.8.1
singledispatch @ file:///tmp/build/80754af9/singledispatch_1614366001199/work
sip==4.19.13
six @ file:///tmp/build/80754af9/six_1605205327372/work
sniffio @ file:///tmp/build/80754af9/sniffio_1614030475067/work
snowballstemmer @ file:///tmp/build/80754af9/snowballstemmer_1611258885636/work
sortedcollections @ file:///tmp/build/80754af9/sortedcollections_1611172717284/work
sortedcontainers @ file:///tmp/build/80754af9/sortedcontainers_1606865132123/work
soupsieve @ file:///tmp/build/80754af9/soupsieve_1616183228191/work
Sphinx @ file:///tmp/build/80754af9/sphinx_1616268783226/work
sphinxcontrib-applehelp @ file:///home/ktietz/src/ci/sphinxcontrib-applehelp_1611920841464/work
sphinxcontrib-devhelp @ file:///home/ktietz/src/ci/sphinxcontrib-devhelp_1611920923094/work
sphinxcontrib-htmlhelp @ file:///home/ktietz/src/ci/sphinxcontrib-htmlhelp_1611920974801/work
sphinxcontrib-jsmath @ file:///home/ktietz/src/ci/sphinxcontrib-jsmath_1611920942228/work
sphinxcontrib-qthelp @ file:///home/ktietz/src/ci/sphinxcontrib-qthelp_1611921055322/work
sphinxcontrib-serializinghtml @ file:///home/ktietz/src/ci/sphinxcontrib-serializinghtml_1611920755253/work
sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work
spyder @ file:///tmp/build/80754af9/spyder_1618327905127/work
spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1617396566288/work
SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1618089170652/work
statsmodels @ file:///tmp/build/80754af9/statsmodels_1614023746358/work
sympy @ file:///tmp/build/80754af9/sympy_1618252284338/work
tables==3.6.1
tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work
tensorboardX==2.2
terminado==0.9.4
testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work
text-unidecode==1.3
textdistance @ file:///tmp/build/80754af9/textdistance_1612461398012/work
threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl
three-merge @ file:///tmp/build/80754af9/three-merge_1607553261110/work
tinycss @ file:///tmp/build/80754af9/tinycss_1617713798712/work
toml @ file:///tmp/build/80754af9/toml_1616166611790/work
toolz @ file:///home/linux1/recipes/ci/toolz_1610987900194/work
torch==1.7.0
torch-cluster==1.5.9
torch-geometric==1.7.0
torch-scatter==2.0.6
torch-sparse==0.6.9
torch-spline-conv==1.2.1
torchaudio==0.7.0a0+ac17b64
torchvision==0.8.0
tornado @ file:///tmp/build/80754af9/tornado_1606942300299/work
tqdm @ file:///tmp/build/80754af9/tqdm_1615925068909/work
traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
tsBNgen==1.0.0
typed-ast @ file:///tmp/build/80754af9/typed-ast_1610484547928/work
typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
ujson @ file:///tmp/build/80754af9/ujson_1611259522456/work
unicodecsv==0.14.1
Unidecode @ file:///tmp/build/80754af9/unidecode_1614712377438/work
urllib3 @ file:///tmp/build/80754af9/urllib3_1615837158687/work
watchdog @ file:///tmp/build/80754af9/watchdog_1612471027849/work
wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
webencodings==0.5.1
Werkzeug @ file:///home/ktietz/src/ci/werkzeug_1611932622770/work
whichcraft @ file:///tmp/build/80754af9/whichcraft_1617751293875/work
widgetsnbextension==3.5.1
wrapt==1.12.1
wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1617224664226/work
xlrd @ file:///tmp/build/80754af9/xlrd_1608072521494/work
XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1617224712951/work
xlwt==1.3.0
yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work
zict==2.0.0
zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work
zope.event==4.5.0
zope.interface @ file:///tmp/build/80754af9/zope.interface_1616357211867/work