Full Code of xxlya/BrainGNN_Pytorch for AI

main 1e337e7a13af cached
24 files
112.3 KB
31.7k tokens
68 symbols
1 requests
Download .txt
Repository: xxlya/BrainGNN_Pytorch
Branch: main
Commit: 1e337e7a13af
Files: 24
Total size: 112.3 KB

Directory structure:
gitextract_qan2yim0/

├── .idea/
│   ├── .gitignore
│   ├── GNN_biomarker_MEDIA.iml
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── inspectionProfiles/
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── webServers.xml
├── 01-fetch_data.py
├── 02-process_data.py
├── 03-main.py
├── README.md
├── data/
│   └── subject_ID.txt
├── imports/
│   ├── ABIDEDataset.py
│   ├── __inits__.py
│   ├── gdc.py
│   ├── preprocess_data.py
│   ├── read_abide_stats_parall.py
│   └── utils.py
├── net/
│   ├── braingnn.py
│   ├── braingraphconv.py
│   ├── brainmsgpassing.py
│   └── inits.py
└── requirements.txt

================================================
FILE CONTENTS
================================================

================================================
FILE: .idea/.gitignore
================================================
# Default ignored files
/shelf/
/workspace.xml
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/


================================================
FILE: .idea/GNN_biomarker_MEDIA.iml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="li-cancer Remote Python 3.8.5 (sftp://xiaoxiaol@localhost:6000/data/xiaoxiaol/anaconda3/envs/cancergnn/bin/python)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>

================================================
FILE: .idea/deployment.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PublishConfigData" serverName="li-gan">
    <serverData>
      <paths name="ipag">
        <serverdata>
          <mappings>
            <mapping deploy="/data/xiaoxiaol/ipag/GNN_HBM/" local="$PROJECT_DIR$/" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="li-gan">
        <serverdata>
          <mappings>
            <mapping deploy="/data/xiaoxiaol/ipag/GNN_HBM" local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="xiaoxiaol@localhost:6000 password">
        <serverdata>
          <mappings>
            <mapping deploy="/data/xiaoxiaol/ipag/GNN_HBM" local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
    </serverData>
  </component>
</project>

================================================
FILE: .idea/encodings.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Encoding" addBOMForNewFiles="with NO BOM" />
</project>

================================================
FILE: .idea/inspectionProfiles/Project_Default.xml
================================================
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
      <option name="ignoredPackages">
        <value>
          <list size="200">
            <item index="0" class="java.lang.String" itemvalue="gensim" />
            <item index="1" class="java.lang.String" itemvalue="torch-scatter" />
            <item index="2" class="java.lang.String" itemvalue="unity-scope-colourlovers" />
            <item index="3" class="java.lang.String" itemvalue="scikit-learn" />
            <item index="4" class="java.lang.String" itemvalue="testpath" />
            <item index="5" class="java.lang.String" itemvalue="ufw" />
            <item index="6" class="java.lang.String" itemvalue="py" />
            <item index="7" class="java.lang.String" itemvalue="torchvision" />
            <item index="8" class="java.lang.String" itemvalue="catfish" />
            <item index="9" class="java.lang.String" itemvalue="ipython-genutils" />
            <item index="10" class="java.lang.String" itemvalue="bz2file" />
            <item index="11" class="java.lang.String" itemvalue="python-louvain" />
            <item index="12" class="java.lang.String" itemvalue="bleach" />
            <item index="13" class="java.lang.String" itemvalue="graphviz" />
            <item index="14" class="java.lang.String" itemvalue="lxml" />
            <item index="15" class="java.lang.String" itemvalue="language-selector" />
            <item index="16" class="java.lang.String" itemvalue="jsonschema" />
            <item index="17" class="java.lang.String" itemvalue="xlrd" />
            <item index="18" class="java.lang.String" itemvalue="Werkzeug" />
            <item index="19" class="java.lang.String" itemvalue="wordcloud" />
            <item index="20" class="java.lang.String" itemvalue="python-apt" />
            <item index="21" class="java.lang.String" itemvalue="click" />
            <item index="22" class="java.lang.String" itemvalue="pyxdg" />
            <item index="23" class="java.lang.String" itemvalue="tensorboard" />
            <item index="24" class="java.lang.String" itemvalue="imageio" />
            <item index="25" class="java.lang.String" itemvalue="matplotlib" />
            <item index="26" class="java.lang.String" itemvalue="Keras" />
            <item index="27" class="java.lang.String" itemvalue="Mako" />
            <item index="28" class="java.lang.String" itemvalue="idna" />
            <item index="29" class="java.lang.String" itemvalue="colorgram.py" />
            <item index="30" class="java.lang.String" itemvalue="networkx" />
            <item index="31" class="java.lang.String" itemvalue="pycurl" />
            <item index="32" class="java.lang.String" itemvalue="pluggy" />
            <item index="33" class="java.lang.String" itemvalue="torch-sparse" />
            <item index="34" class="java.lang.String" itemvalue="unity-scope-manpages" />
            <item index="35" class="java.lang.String" itemvalue="screen-resolution-extra" />
            <item index="36" class="java.lang.String" itemvalue="jupyter" />
            <item index="37" class="java.lang.String" itemvalue="PyWavelets" />
            <item index="38" class="java.lang.String" itemvalue="sessioninstaller" />
            <item index="39" class="java.lang.String" itemvalue="smart-open" />
            <item index="40" class="java.lang.String" itemvalue="prompt-toolkit" />
            <item index="41" class="java.lang.String" itemvalue="rcssmin" />
            <item index="42" class="java.lang.String" itemvalue="tensorflow-tensorboard" />
            <item index="43" class="java.lang.String" itemvalue="astor" />
            <item index="44" class="java.lang.String" itemvalue="pathlib2" />
            <item index="45" class="java.lang.String" itemvalue="unity-scope-devhelp" />
            <item index="46" class="java.lang.String" itemvalue="pytest-runner" />
            <item index="47" class="java.lang.String" itemvalue="unity-scope-tomboy" />
            <item index="48" class="java.lang.String" itemvalue="olefile" />
            <item index="49" class="java.lang.String" itemvalue="pytz" />
            <item index="50" class="java.lang.String" itemvalue="python-systemd" />
            <item index="51" class="java.lang.String" itemvalue="traitlets" />
            <item index="52" class="java.lang.String" itemvalue="absl-py" />
            <item index="53" class="java.lang.String" itemvalue="protobuf" />
            <item index="54" class="java.lang.String" itemvalue="joblib" />
            <item index="55" class="java.lang.String" itemvalue="lib" />
            <item index="56" class="java.lang.String" itemvalue="nltk" />
            <item index="57" class="java.lang.String" itemvalue="atomicwrites" />
            <item index="58" class="java.lang.String" itemvalue="pycups" />
            <item index="59" class="java.lang.String" itemvalue="unity-scope-zotero" />
            <item index="60" class="java.lang.String" itemvalue="gast" />
            <item index="61" class="java.lang.String" itemvalue="unity-scope-yelp" />
            <item index="62" class="java.lang.String" itemvalue="pyzmq" />
            <item index="63" class="java.lang.String" itemvalue="oauthlib" />
            <item index="64" class="java.lang.String" itemvalue="entrypoints" />
            <item index="65" class="java.lang.String" itemvalue="tensorflow-gpu" />
            <item index="66" class="java.lang.String" itemvalue="beautifulsoup4" />
            <item index="67" class="java.lang.String" itemvalue="argcomplete" />
            <item index="68" class="java.lang.String" itemvalue="cryptography" />
            <item index="69" class="java.lang.String" itemvalue="Theano" />
            <item index="70" class="java.lang.String" itemvalue="keras-vis" />
            <item index="71" class="java.lang.String" itemvalue="mugshot" />
            <item index="72" class="java.lang.String" itemvalue="widgetsnbextension" />
            <item index="73" class="java.lang.String" itemvalue="tensorly" />
            <item index="74" class="java.lang.String" itemvalue="numexpr" />
            <item index="75" class="java.lang.String" itemvalue="distro" />
            <item index="76" class="java.lang.String" itemvalue="defer" />
            <item index="77" class="java.lang.String" itemvalue="jupyter-core" />
            <item index="78" class="java.lang.String" itemvalue="pydot" />
            <item index="79" class="java.lang.String" itemvalue="menulibre" />
            <item index="80" class="java.lang.String" itemvalue="httplib2" />
            <item index="81" class="java.lang.String" itemvalue="wcwidth" />
            <item index="82" class="java.lang.String" itemvalue="apturl" />
            <item index="83" class="java.lang.String" itemvalue="Jinja2" />
            <item index="84" class="java.lang.String" itemvalue="Keras-Preprocessing" />
            <item index="85" class="java.lang.String" itemvalue="pytest-cov" />
            <item index="86" class="java.lang.String" itemvalue="torch-geometric" />
            <item index="87" class="java.lang.String" itemvalue="coverage" />
            <item index="88" class="java.lang.String" itemvalue="six" />
            <item index="89" class="java.lang.String" itemvalue="plainbox" />
            <item index="90" class="java.lang.String" itemvalue="system-service" />
            <item index="91" class="java.lang.String" itemvalue="parso" />
            <item index="92" class="java.lang.String" itemvalue="ipython" />
            <item index="93" class="java.lang.String" itemvalue="chardet" />
            <item index="94" class="java.lang.String" itemvalue="face-recognition-models" />
            <item index="95" class="java.lang.String" itemvalue="command-not-found" />
            <item index="96" class="java.lang.String" itemvalue="tabulate" />
            <item index="97" class="java.lang.String" itemvalue="PyYAML" />
            <item index="98" class="java.lang.String" itemvalue="pickleshare" />
            <item index="99" class="java.lang.String" itemvalue="SimpleCV" />
            <item index="100" class="java.lang.String" itemvalue="tables" />
            <item index="101" class="java.lang.String" itemvalue="Pygments" />
            <item index="102" class="java.lang.String" itemvalue="imutils" />
            <item index="103" class="java.lang.String" itemvalue="qtconsole" />
            <item index="104" class="java.lang.String" itemvalue="terminado" />
            <item index="105" class="java.lang.String" itemvalue="python-igraph" />
            <item index="106" class="java.lang.String" itemvalue="plyfile" />
            <item index="107" class="java.lang.String" itemvalue="torch-cluster" />
            <item index="108" class="java.lang.String" itemvalue="reportlab" />
            <item index="109" class="java.lang.String" itemvalue="jupyter-client" />
            <item index="110" class="java.lang.String" itemvalue="pexpect" />
            <item index="111" class="java.lang.String" itemvalue="ipykernel" />
            <item index="112" class="java.lang.String" itemvalue="nbconvert" />
            <item index="113" class="java.lang.String" itemvalue="attrs" />
            <item index="114" class="java.lang.String" itemvalue="psutil" />
            <item index="115" class="java.lang.String" itemvalue="svgwrite" />
            <item index="116" class="java.lang.String" itemvalue="jedi" />
            <item index="117" class="java.lang.String" itemvalue="numpy-groupies" />
            <item index="118" class="java.lang.String" itemvalue="padme" />
            <item index="119" class="java.lang.String" itemvalue="pygobject" />
            <item index="120" class="java.lang.String" itemvalue="msgpack" />
            <item index="121" class="java.lang.String" itemvalue="unity-scope-chromiumbookmarks" />
            <item index="122" class="java.lang.String" itemvalue="PyJWT" />
            <item index="123" class="java.lang.String" itemvalue="onboard" />
            <item index="124" class="java.lang.String" itemvalue="pydiffmap" />
            <item index="125" class="java.lang.String" itemvalue="pandocfilters" />
            <item index="126" class="java.lang.String" itemvalue="slimit" />
            <item index="127" class="java.lang.String" itemvalue="unity-scope-virtualbox" />
            <item index="128" class="java.lang.String" itemvalue="lightdm-gtk-greeter-settings" />
            <item index="129" class="java.lang.String" itemvalue="pyasn1" />
            <item index="130" class="java.lang.String" itemvalue="requests" />
            <item index="131" class="java.lang.String" itemvalue="nilearn" />
            <item index="132" class="java.lang.String" itemvalue="XlsxWriter" />
            <item index="133" class="java.lang.String" itemvalue="seaborn" />
            <item index="134" class="java.lang.String" itemvalue="cached-property" />
            <item index="135" class="java.lang.String" itemvalue="xgboost" />
            <item index="136" class="java.lang.String" itemvalue="ipywidgets" />
            <item index="137" class="java.lang.String" itemvalue="blinker" />
            <item index="138" class="java.lang.String" itemvalue="ubuntu-drivers-common" />
            <item index="139" class="java.lang.String" itemvalue="scipy" />
            <item index="140" class="java.lang.String" itemvalue="tornado" />
            <item index="141" class="java.lang.String" itemvalue="opencv-python" />
            <item index="142" class="java.lang.String" itemvalue="unity-scope-firefoxbookmarks" />
            <item index="143" class="java.lang.String" itemvalue="xkit" />
            <item index="144" class="java.lang.String" itemvalue="torch" />
            <item index="145" class="java.lang.String" itemvalue="mistune" />
            <item index="146" class="java.lang.String" itemvalue="pandas" />
            <item index="147" class="java.lang.String" itemvalue="shap" />
            <item index="148" class="java.lang.String" itemvalue="termcolor" />
            <item index="149" class="java.lang.String" itemvalue="torch-spline-conv" />
            <item index="150" class="java.lang.String" itemvalue="future" />
            <item index="151" class="java.lang.String" itemvalue="jupyter-console" />
            <item index="152" class="java.lang.String" itemvalue="unity-scope-texdoc" />
            <item index="153" class="java.lang.String" itemvalue="usb-creator" />
            <item index="154" class="java.lang.String" itemvalue="Pillow" />
            <item index="155" class="java.lang.String" itemvalue="html5lib" />
            <item index="156" class="java.lang.String" itemvalue="Brlapi" />
            <item index="157" class="java.lang.String" itemvalue="python-dateutil" />
            <item index="158" class="java.lang.String" itemvalue="MarkupSafe" />
            <item index="159" class="java.lang.String" itemvalue="feedparser" />
            <item index="160" class="java.lang.String" itemvalue="tflearn" />
            <item index="161" class="java.lang.String" itemvalue="msgpack-numpy" />
            <item index="162" class="java.lang.String" itemvalue="segraph" />
            <item index="163" class="java.lang.String" itemvalue="unattended-upgrades" />
            <item index="164" class="java.lang.String" itemvalue="Markdown" />
            <item index="165" class="java.lang.String" itemvalue="notebook" />
            <item index="166" class="java.lang.String" itemvalue="rpy2" />
            <item index="167" class="java.lang.String" itemvalue="boto" />
            <item index="168" class="java.lang.String" itemvalue="python-gnupg" />
            <item index="169" class="java.lang.String" itemvalue="tensorpack" />
            <item index="170" class="java.lang.String" itemvalue="ssh-import-id" />
            <item index="171" class="java.lang.String" itemvalue="unity-scope-openclipart" />
            <item index="172" class="java.lang.String" itemvalue="panorama" />
            <item index="173" class="java.lang.String" itemvalue="progressbar" />
            <item index="174" class="java.lang.String" itemvalue="virtualenv" />
            <item index="175" class="java.lang.String" itemvalue="Ubuntu-Make" />
            <item index="176" class="java.lang.String" itemvalue="Augmentor" />
            <item index="177" class="java.lang.String" itemvalue="enum34" />
            <item index="178" class="java.lang.String" itemvalue="checkbox-support" />
            <item index="179" class="java.lang.String" itemvalue="scikit-image" />
            <item index="180" class="java.lang.String" itemvalue="guacamole" />
            <item index="181" class="java.lang.String" itemvalue="ptyprocess" />
            <item index="182" class="java.lang.String" itemvalue="more-itertools" />
            <item index="183" class="java.lang.String" itemvalue="simplegeneric" />
            <item index="184" class="java.lang.String" itemvalue="python-debian" />
            <item index="185" class="java.lang.String" itemvalue="python-resize-image" />
            <item index="186" class="java.lang.String" itemvalue="louis" />
            <item index="187" class="java.lang.String" itemvalue="urllib3" />
            <item index="188" class="java.lang.String" itemvalue="Cython" />
            <item index="189" class="java.lang.String" itemvalue="unity-scope-gdrive" />
            <item index="190" class="java.lang.String" itemvalue="pytest" />
            <item index="191" class="java.lang.String" itemvalue="nbformat" />
            <item index="192" class="java.lang.String" itemvalue="xdiagnose" />
            <item index="193" class="java.lang.String" itemvalue="Keras-Applications" />
            <item index="194" class="java.lang.String" itemvalue="scikit-plot" />
            <item index="195" class="java.lang.String" itemvalue="tqdm" />
            <item index="196" class="java.lang.String" itemvalue="grpcio" />
            <item index="197" class="java.lang.String" itemvalue="deepdish" />
            <item index="198" class="java.lang.String" itemvalue="unity-scope-calculator" />
            <item index="199" class="java.lang.String" itemvalue="ply" />
          </list>
        </value>
      </option>
    </inspection_tool>
  </profile>
</component>

================================================
FILE: .idea/misc.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="li-cancer Remote Python 3.8.5 (sftp://xiaoxiaol@localhost:6000/data/xiaoxiaol/anaconda3/envs/cancergnn/bin/python)" project-jdk-type="Python SDK" />
  <component name="PyCharmProfessionalAdvertiser">
    <option name="shown" value="true" />
  </component>
</project>

================================================
FILE: .idea/modules.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/GNN_biomarker_MEDIA.iml" filepath="$PROJECT_DIR$/.idea/GNN_biomarker_MEDIA.iml" />
    </modules>
  </component>
</project>

================================================
FILE: .idea/webServers.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="WebServers">
    <option name="servers">
      <webServer id="58b70e43-c401-48fa-983a-32280c016f57" name="ipag">
        <fileTransfer accessType="SFTP" host="localhost" port="6000" sshConfigId="dba9c212-8899-4954-a857-6abbb7000465" sshConfig="xiaoxiaol@localhost:6000 password">
          <advancedOptions>
            <advancedOptions dataProtectionLevel="Private" passiveMode="true" shareSSLContext="true" />
          </advancedOptions>
        </fileTransfer>
      </webServer>
    </option>
  </component>
</project>

================================================
FILE: 01-fetch_data.py
================================================
# Copyright (c) 2019 Mwiza Kunda
# Copyright (C) 2017 Sarah Parisot <s.parisot@imperial.ac.uk>, , Sofia Ira Ktena <ira.ktena@imperial.ac.uk>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

'''
This script mainly refers to https://github.com/kundaMwiza/fMRI-site-adaptation/blob/master/fetch_data.py
'''

from nilearn import datasets
import argparse
from imports import preprocess_data as Reader
import os
import shutil
import sys

# Input data variables
code_folder = os.getcwd()  # directory the script is launched from (not used further below)
root_folder = '/data/'  # root under which nilearn will download the ABIDE release
data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal/')
# Ensure the derivative folder exists before seeding it with the subject list.
if not os.path.exists(data_folder):
    os.makedirs(data_folder)
# Copy the fixed subject list next to the downloaded data (renamed to subject_IDs.txt,
# the name Reader.get_ids() expects).
# NOTE(review): this reads /data/subject_ID.txt, but the repository ships the file at
# data/subject_ID.txt -- confirm it has been placed under root_folder before running.
shutil.copyfile(os.path.join(root_folder,'subject_ID.txt'), os.path.join(data_folder, 'subject_IDs.txt'))

def str2bool(v):
    """Interpret a command-line value as a boolean.

    Intended for use as an argparse ``type=`` callable. Booleans pass
    through unchanged; strings are matched case-insensitively against
    common truthy/falsy spellings.

    Args:
        v: a bool, or a string such as 'yes'/'no', 'true'/'false',
           't'/'f', 'y'/'n', '1'/'0' (any case).

    Returns:
        The parsed boolean value.

    Raises:
        argparse.ArgumentTypeError: if the string is not recognized.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    if normalized in {'yes', 'true', 't', 'y', '1'}:
        return True
    if normalized in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def main():
    """Download the preprocessed ABIDE dataset and compute connectivity matrices.

    Fetches the requested ROI time-series derivative via nilearn (unless
    --download false), reorganizes the downloaded files into one folder per
    subject under ``data_folder``, then computes and saves both correlation
    and partial-correlation connectivity matrices for every subject.

    Side effects: network download, directory creation and file moves under
    ``data_folder``; no return value.
    """
    parser = argparse.ArgumentParser(description='Download ABIDE data and compute functional connectivity matrices')
    parser.add_argument('--pipeline', default='cpac', type=str,
                        help='Pipeline to preprocess ABIDE data. Available options are ccs, cpac, dparsf and niak.'
                             ' default: cpac.')
    parser.add_argument('--atlas', default='cc200',
                        help='Brain parcellation atlas. Options: ho, cc200 and cc400, default: cc200.')
    parser.add_argument('--download', default=True, type=str2bool,
                        help='Dowload data or just compute functional connectivity. default: True')
    args = parser.parse_args()
    print(args)

    pipeline = args.pipeline
    atlas = args.atlas
    download = args.download

    # Derivatives to fetch: the ROI time series for the chosen atlas.
    files = ['rois_' + atlas]

    # Maps each derivative name to the suffix of its file on disk.
    filemapping = {'func_preproc': 'func_preproc.nii.gz',
                   files[0]: files[0] + '.1D'}

    # Download database files; when --download false, the data is assumed to
    # already exist under root_folder.
    if download:
        datasets.fetch_abide_pcp(data_dir=root_folder, pipeline=pipeline,
                                 band_pass_filtering=True, global_signal_regression=False,
                                 derivatives=files, quality_checked=False)

    subject_IDs = Reader.get_ids()  # reads subject_IDs.txt from the data path
    subject_IDs = subject_IDs.tolist()

    # Create a folder for each subject and move its derivative files into it
    # (skipping files that were already moved on a previous run).
    for s, fname in zip(subject_IDs, Reader.fetch_filenames(subject_IDs, files[0], atlas)):
        subject_folder = os.path.join(data_folder, s)
        if not os.path.exists(subject_folder):
            os.mkdir(subject_folder)

        # Base filename shared by every derivative of this subject.
        base = fname.split(files[0])[0]

        for fl in files:
            if not os.path.exists(os.path.join(subject_folder, base + filemapping[fl])):
                shutil.move(base + filemapping[fl], subject_folder)

    time_series = Reader.get_timeseries(subject_IDs, atlas)

    # Compute and save both full and partial correlation connectivity matrices.
    Reader.subject_connectivity(time_series, subject_IDs, atlas, 'correlation')
    Reader.subject_connectivity(time_series, subject_IDs, atlas, 'partial correlation')


if __name__ == '__main__':
    main()


================================================
FILE: 02-process_data.py
================================================
# Copyright (c) 2019 Mwiza Kunda
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import sys
import argparse
import pandas as pd
import numpy as np
from imports import preprocess_data as Reader
import deepdish as dd
import warnings
import os

# Silence library warnings (deprecation noise from the scientific stack) in console output.
warnings.filterwarnings("ignore")
root_folder = '/data/'  # root under which 01-fetch_data.py downloaded the ABIDE release
data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal/')  # per-subject derivatives live here

# Process boolean command line arguments
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def main():
    """Build the per-subject raw inputs consumed by the BrainGNN pipeline.

    Loads the connectivity matrices produced by 01-fetch_data.py, derives a
    binary class label per subject from the phenotypic DX_GROUP code
    (1 = autism, 2 = control), and writes one HDF5 file per subject under
    ``<data_folder>/raw`` holding its correlation matrix, partial-correlation
    matrix and label.
    """
    parser = argparse.ArgumentParser(description='Classification of the ABIDE dataset using a Ridge classifier. '
                                                 'MIDA is used to minimize the distribution mismatch between ABIDE sites')
    parser.add_argument('--atlas', default='cc200',
                        help='Atlas for network construction (node definition) options: ho, cc200, cc400, default: cc200.')
    # Fixed: help text previously claimed "default: 1234." while the actual default is 123.
    parser.add_argument('--seed', default=123, type=int, help='Seed for random initialisation. default: 123.')
    parser.add_argument('--nclass', default=2, type=int, help='Number of classes. default:2')

    args = parser.parse_args()
    print('Arguments: \n', args)

    params = dict()

    # NOTE(review): the seed is recorded but no RNG is seeded anywhere in this
    # script -- confirm whether np.random.seed(args.seed) was intended here.
    params['seed'] = args.seed  # seed for random initialisation

    # Algorithm choice
    params['atlas'] = args.atlas  # Atlas for network construction
    atlas = args.atlas  # Atlas for network construction (node definition)

    # Get subject IDs and class labels
    subject_IDs = Reader.get_ids()
    labels = Reader.get_subject_score(subject_IDs, score='DX_GROUP')

    # Number of subjects and classes for binary classification
    num_classes = args.nclass
    num_subjects = len(subject_IDs)
    params['n_subjects'] = num_subjects

    # Initialise variables for class labels
    # DX_GROUP coding: 1 is autism, 2 is control
    y_data = np.zeros([num_subjects, num_classes])  # one-hot labels, n x 2
    y = np.zeros([num_subjects, 1])                 # raw DX_GROUP codes, n x 1

    # Get class labels for all subjects
    for i in range(num_subjects):
        y_data[i, int(labels[subject_IDs[i]]) - 1] = 1
        y[i] = int(labels[subject_IDs[i]])

    # Compute feature matrices (stacked connectivity networks); (n, 200, 200) for cc200
    fea_corr = Reader.get_networks(subject_IDs, iter_no='', kind='correlation', atlas_name=atlas)
    fea_pcorr = Reader.get_networks(subject_IDs, iter_no='', kind='partial correlation', atlas_name=atlas)

    # Save one .h5 per subject; y % 2 maps DX_GROUP 1 (autism) -> 1 and 2 (control) -> 0.
    raw_dir = os.path.join(data_folder, 'raw')
    os.makedirs(raw_dir, exist_ok=True)
    for i, subject in enumerate(subject_IDs):
        dd.io.save(os.path.join(raw_dir, subject + '.h5'),
                   {'corr': fea_corr[i], 'pcorr': fea_pcorr[i], 'label': y[i] % 2})


if __name__ == '__main__':
    main()


================================================
FILE: 03-main.py
================================================
import os
import numpy as np
import argparse
import time
import copy

import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
from tensorboardX import SummaryWriter

from imports.ABIDEDataset import ABIDEDataset
from torch_geometric.data import DataLoader
from net.braingnn import Network
from imports.utils import train_val_test_split
from sklearn.metrics import classification_report, confusion_matrix

torch.manual_seed(123)

EPS = 1e-10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Command-line configuration for training/evaluation of BrainGNN on ABIDE.
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training')
parser.add_argument('--batchSize', type=int, default=100, help='size of the batches')
parser.add_argument('--dataroot', type=str, default='/home/azureuser/projects/BrainGNN/data/ABIDE_pcp/cpac/filt_noglobal', help='root directory of the dataset')
parser.add_argument('--fold', type=int, default=0, help='training which fold')
parser.add_argument('--lr', type = float, default=0.01, help='learning rate')
parser.add_argument('--stepsize', type=int, default=20, help='scheduler step size')
parser.add_argument('--gamma', type=float, default=0.5, help='scheduler shrinking rate')
parser.add_argument('--weightdecay', type=float, default=5e-3, help='regularization')
# lamb0..lamb5 weight the individual loss terms combined in train()/test_loss().
parser.add_argument('--lamb0', type=float, default=1, help='classification loss weight')
parser.add_argument('--lamb1', type=float, default=0, help='s1 unit regularization')
parser.add_argument('--lamb2', type=float, default=0, help='s2 unit regularization')
parser.add_argument('--lamb3', type=float, default=0.1, help='s1 entropy regularization')
parser.add_argument('--lamb4', type=float, default=0.1, help='s2 entropy regularization')
parser.add_argument('--lamb5', type=float, default=0.1, help='s1 consistence regularization')
parser.add_argument('--layer', type=int, default=2, help='number of GNN layers')
parser.add_argument('--ratio', type=float, default=0.5, help='pooling ratio')
parser.add_argument('--indim', type=int, default=200, help='feature dim')
parser.add_argument('--nroi', type=int, default=200, help='num of ROIs')
parser.add_argument('--nclass', type=int, default=2, help='num of classes')
# NOTE(review): argparse's type=bool does not parse booleans — any non-empty
# string (including "False") evaluates to True. Only the defaults behave as
# intended; consider action='store_true' or a str-to-bool converter.
parser.add_argument('--load_model', type=bool, default=False)
parser.add_argument('--save_model', type=bool, default=True)
parser.add_argument('--optim', type=str, default='Adam', help='optimization method: SGD, Adam')
parser.add_argument('--save_path', type=str, default='./model/', help='path to save model')
opt = parser.parse_args()

# Ensure the checkpoint directory exists before training starts.
if not os.path.exists(opt.save_path):
    os.makedirs(opt.save_path)

#################### Parameter Initialization #######################
path = opt.dataroot
name = 'ABIDE'
save_model = opt.save_model
load_model = opt.load_model
opt_method = opt.optim
num_epoch = opt.n_epochs
fold = opt.fold
# One TensorBoard log directory per cross-validation fold.
writer = SummaryWriter(os.path.join('./log',str(fold)))



################## Define Dataloader ##################################

dataset = ABIDEDataset(path,name)
dataset.data.y = dataset.data.y.squeeze()
dataset.data.x[dataset.data.x == float('inf')] = 0

tr_index,val_index,te_index = train_val_test_split(fold=fold)
train_dataset = dataset[tr_index]
val_dataset = dataset[val_index]
test_dataset = dataset[te_index]


train_loader = DataLoader(train_dataset,batch_size=opt.batchSize, shuffle= True)
val_loader = DataLoader(val_dataset, batch_size=opt.batchSize, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=opt.batchSize, shuffle=False)



############### Define Graph Deep Learning Network ##########################
model = Network(opt.indim,opt.ratio,opt.nclass).to(device)
print(model)

if opt_method == 'Adam':
    optimizer = torch.optim.Adam(model.parameters(), lr= opt.lr, weight_decay=opt.weightdecay)
elif opt_method == 'SGD':
    optimizer = torch.optim.SGD(model.parameters(), lr =opt.lr, momentum = 0.9, weight_decay=opt.weightdecay, nesterov = True)

scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.stepsize, gamma=opt.gamma)

############################### Define Other Loss Functions ########################################
def topk_loss(s,ratio):
    """Entropy-style regularizer pushing pooling scores toward 0/1.

    Sorts each row of ``s`` and penalizes the top ``ratio`` fraction for being
    far from 1 and the bottom ``ratio`` fraction for being far from 0.

    Args:
        s: 2-D tensor of node scores, one row per graph.
        ratio: pooling ratio; values above 0.5 are mirrored to ``1 - ratio``.

    Returns:
        Scalar tensor loss.
    """
    if ratio > 0.5:
        ratio = 1-ratio
    s = s.sort(dim=1).values
    # Guard: for a small row length or tiny ratio, int(size * ratio) could be
    # 0, which would slice an empty tensor whose mean() is NaN and silently
    # poison the combined training loss.
    k = max(int(s.size(1) * ratio), 1)
    res = -torch.log(s[:, -k:] + EPS).mean() - torch.log(1 - s[:, :k] + EPS).mean()
    return res


def consist_loss(s):
    """Group-consistency penalty on pooling scores.

    Applies a graph-Laplacian smoothness term over a fully connected graph of
    the samples in ``s`` (all-ones affinity), encouraging the sigmoid-squashed
    scores within one class to agree with each other.
    """
    if len(s) == 0:
        return 0
    s = torch.sigmoid(s)
    n = s.shape[0]
    # Fully connected affinity; Laplacian L = D - W with D = diag(row sums).
    W = torch.ones(n, n)
    L = torch.diag(W.sum(dim=1)) - W
    L = L.to(device)
    # tr(s^T L s) normalized by the number of sample pairs.
    return torch.trace(s.t() @ L @ s) / (n * n)

###################### Network Training Function#####################################
def train(epoch):
    """Run one training epoch over ``train_loader``.

    Relies on the module-level ``model``, ``optimizer``, ``scheduler``,
    ``writer`` and ``opt`` objects.

    Args:
        epoch (int): current epoch index, used as a TensorBoard step offset.

    Returns:
        tuple: (mean loss per graph over the training set,
                flattened s1 scores from all batches,
                flattened s2 scores from all batches,
                w1 and w2 from the last batch).
    """
    print('train...........')

    for param_group in optimizer.param_groups:
        print("LR", param_group['lr'])
    model.train()
    s1_list = []
    s2_list = []
    loss_all = 0
    step = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, w1, w2, s1, s2 = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos)
        s1_list.append(s1.view(-1).detach().cpu().numpy())
        s2_list.append(s2.view(-1).detach().cpu().numpy())

        # Classification loss; the model outputs log-probabilities.
        loss_c = F.nll_loss(output, data.y)

        # Unit-norm regularization on the learned weights.
        loss_p1 = (torch.norm(w1, p=2)-1) ** 2
        loss_p2 = (torch.norm(w2, p=2)-1) ** 2
        # Push pooling scores toward a crisp top-k selection.
        loss_tpk1 = topk_loss(s1,opt.ratio)
        loss_tpk2 = topk_loss(s2,opt.ratio)
        # Within-class consistency of s1 scores.
        loss_consist = 0
        for c in range(opt.nclass):
            loss_consist += consist_loss(s1[data.y == c])
        loss = opt.lamb0*loss_c + opt.lamb1 * loss_p1 + opt.lamb2 * loss_p2 \
                   + opt.lamb3 * loss_tpk1 + opt.lamb4 *loss_tpk2 + opt.lamb5* loss_consist
        writer.add_scalar('train/classification_loss', loss_c, epoch*len(train_loader)+step)
        writer.add_scalar('train/unit_loss1', loss_p1, epoch*len(train_loader)+step)
        writer.add_scalar('train/unit_loss2', loss_p2, epoch*len(train_loader)+step)
        writer.add_scalar('train/TopK_loss1', loss_tpk1, epoch*len(train_loader)+step)
        writer.add_scalar('train/TopK_loss2', loss_tpk2, epoch*len(train_loader)+step)
        writer.add_scalar('train/GCL_loss', loss_consist, epoch*len(train_loader)+step)
        step = step + 1

        loss.backward()
        loss_all += loss.item() * data.num_graphs
        optimizer.step()

    # BUGFIX: step the LR scheduler after the epoch's optimizer updates.
    # The original called scheduler.step() at the start of the epoch, before
    # any optimizer.step(); PyTorch >= 1.1 requires the opposite order and the
    # old placement shifted the decay schedule by one epoch.
    scheduler.step()

    # Concatenate once after the loop; doing it every batch was O(n^2).
    s1_arr = np.hstack(s1_list)
    s2_arr = np.hstack(s2_list)
    return loss_all / len(train_dataset), s1_arr, s2_arr ,w1,w2


###################### Network Testing Function#####################################
def test_acc(loader):
    """Compute classification accuracy of the global ``model`` on ``loader``.

    Args:
        loader: a DataLoader yielding torch_geometric batches.

    Returns:
        float: fraction of correctly classified graphs in ``loader.dataset``.
    """
    model.eval()
    correct = 0
    # Inference only: disable autograd to avoid building graphs and wasting
    # memory (outputs are identical with or without gradient tracking).
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            outputs = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos)
            # outputs[0] holds log-probabilities; argmax over classes.
            pred = outputs[0].max(dim=1)[1]
            correct += pred.eq(data.y).sum().item()

    return correct / len(loader.dataset)

def test_loss(loader,epoch):
    """Compute the mean combined loss of the global ``model`` on ``loader``.

    Uses the same weighted sum of losses as train() so that validation loss
    is comparable to training loss.

    Args:
        loader: a DataLoader yielding torch_geometric batches.
        epoch (int): kept for interface compatibility; not used here.

    Returns:
        float: mean combined loss per graph over ``loader.dataset``.
    """
    print('testing...........')
    model.eval()
    loss_all = 0
    # Evaluation only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output, w1, w2, s1, s2 = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos)
            loss_c = F.nll_loss(output, data.y)

            loss_p1 = (torch.norm(w1, p=2)-1) ** 2
            loss_p2 = (torch.norm(w2, p=2)-1) ** 2
            loss_tpk1 = topk_loss(s1,opt.ratio)
            loss_tpk2 = topk_loss(s2,opt.ratio)
            loss_consist = 0
            for c in range(opt.nclass):
                loss_consist += consist_loss(s1[data.y == c])
            loss = opt.lamb0*loss_c + opt.lamb1 * loss_p1 + opt.lamb2 * loss_p2 \
                       + opt.lamb3 * loss_tpk1 + opt.lamb4 *loss_tpk2 + opt.lamb5* loss_consist

            loss_all += loss.item() * data.num_graphs
    return loss_all / len(loader.dataset)

#######################################################################################
############################   Model Training #########################################
#######################################################################################
# Track the best model (by validation loss) seen so far.
best_model_wts = copy.deepcopy(model.state_dict())
best_loss = 1e10
for epoch in range(0, num_epoch):
    since  = time.time()
    tr_loss, s1_arr, s2_arr, w1, w2 = train(epoch)
    tr_acc = test_acc(train_loader)
    val_acc = test_acc(val_loader)
    val_loss = test_loss(val_loader,epoch)
    time_elapsed = time.time() - since
    print('*====**')
    print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # NOTE(review): the labels say "Test" but these are validation metrics.
    print('Epoch: {:03d}, Train Loss: {:.7f}, '
          'Train Acc: {:.7f}, Test Loss: {:.7f}, Test Acc: {:.7f}'.format(epoch, tr_loss,
                                                       tr_acc, val_loss, val_acc))

    # TensorBoard: scalar curves plus histograms of the pooling scores.
    writer.add_scalars('Acc',{'train_acc':tr_acc,'val_acc':val_acc},  epoch)
    writer.add_scalars('Loss', {'train_loss': tr_loss, 'val_loss': val_loss},  epoch)
    writer.add_histogram('Hist/hist_s1', s1_arr, epoch)
    writer.add_histogram('Hist/hist_s2', s2_arr, epoch)

    # Checkpoint the best validation loss, skipping the first warm-up epochs.
    if val_loss < best_loss and epoch > 5:
        print("saving best model")
        best_loss = val_loss
        best_model_wts = copy.deepcopy(model.state_dict())
        if save_model:
            torch.save(best_model_wts, os.path.join(opt.save_path,str(fold)+'.pth'))

#######################################################################################
######################### Testing on testing set ######################################
#######################################################################################

# Final evaluation: either reload a saved checkpoint and report on the
# validation split, or evaluate the in-memory best weights on the test split.
if opt.load_model:
    model = Network(opt.indim,opt.ratio,opt.nclass).to(device)
    model.load_state_dict(torch.load(os.path.join(opt.save_path,str(fold)+'.pth')))
    model.eval()
    preds = []
    correct = 0
    for data in val_loader:
        data = data.to(device)
        outputs= model(data.x, data.edge_index, data.batch, data.edge_attr,data.pos)
        pred = outputs[0].max(1)[1]
        preds.append(pred.cpu().detach().numpy())
        correct += pred.eq(data.y).sum().item()
    preds = np.concatenate(preds,axis=0)
    # NOTE(review): val_dataset.data.y may be the *full* dataset's label
    # tensor (torch_geometric InMemoryDataset slices share the underlying
    # data object), in which case trues and preds have different lengths /
    # orderings — verify against dataset[val_index].y gathered per sample.
    trues = val_dataset.data.y.cpu().detach().numpy()
    cm = confusion_matrix(trues,preds)
    # NOTE(review): the header says "Confusion matrix" but `cm` itself is
    # never printed; only the classification report is.
    print("Confusion matrix")
    print(classification_report(trues, preds))

else:
   model.load_state_dict(best_model_wts)
   model.eval()
   test_accuracy = test_acc(test_loader)
   test_l= test_loss(test_loader,0)
   print("===========================")
   print("Test Acc: {:.7f}, Test Loss: {:.7f} ".format(test_accuracy, test_l))
   print(opt)



================================================
FILE: README.md
================================================
# Graph Neural Network for Brain Network Analysis
 A preliminary implementation of BrainGNN. The example presented here uses the public resting-state fMRI ABIDE dataset for convenience of development. This dataset is different from the ones used in our publication, which are cleaner task-fMRI datasets. We are still seeking solutions to improve representation learning on this noisy data.


## Usage
### Setup
**pip**

See the `requirements.txt` for environment configuration. 
```bash
pip install -r requirements.txt
```
**PYG**

To install pyg library, [please refer to the document](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html)

### Dataset 
**ABIDE**

We treat each fMRI as a brain graph. How to download and construct the graphs?
```
python 01-fetch_data.py
python 02-process_data.py
```

### How to run classification?
Training and testing are integrated in the file `03-main.py`. To run
```
python 03-main.py 
```


## Citation
If you find the code and dataset useful, please cite our paper.
```latex
@article{li2020braingnn,
  title={Braingnn: Interpretable brain graph neural network for fmri analysis},
  author={Li, Xiaoxiao and Zhou,Yuan and Dvornek, Nicha and Zhang, Muhan and Gao, Siyuan and Zhuang, Juntang and Scheinost, Dustin and Staib, Lawrence and Ventola, Pamela and Duncan, James},
  journal={bioRxiv},
  year={2020},
  publisher={Cold Spring Harbor Laboratory}
}
```


================================================
FILE: data/subject_ID.txt
================================================
50128
51203
50325
50117
50573
50741
50779
51009
50746
50574
50110
50322
51036
51204
50119
50126
50314
51490
50784
51464
51000
51038
50748
51235
51007
51463
50783
50777
50313
50121
51053
51261
50723
50511
51295
50347
50982
50976
51098
51292
50340
50516
50724
51266
51054
50186
50529
50985
50520
50376
50978
50144
51096
50382
51250
51062
50349
51065
50385
51257
50143
51091
50371
50527
51268
50188
50518
50749
51039
50776
50120
50312
51006
51234
50782
51462
50118
51465
50785
51001
50315
50127
51491
51008
50778
51205
50575
50747
50111
50129
50116
50324
50740
50572
51030
51202
50370
50142
51090
50526
51256
51064
50519
50189
51269
51063
50383
51251
50521
50145
51097
50979
50377
50348
51055
50187
51267
51293
50341
50725
51258
50984
50528
50970
50510
50722
51294
50346
51260
51052
51099
50977
50379
50983
50039
50496
51312
50234
50006
50650
50802
50668
51118
50657
50233
51127
51315
50491
50008
50498
50037
50205
50661
51581
50453
50695
51575
51111
51323
51129
50659
51324
51116
51572
50692
50666
50202
50030
51142
51370
50269
51189
50251
50407
50438
51348
50603
50267
51187
50055
51341
50293
51173
51174
51346
50294
51180
50052
50260
50604
50436
50658
51128
50667
50455
50031
50203
51117
51325
50693
51573
50499
50009
51574
50694
51322
51110
50204
50036
51580
50660
50803
50669
51314
51126
50490
50656
50232
50038
50804
50007
50235
50651
50463
50497
51121
51313
50261
51181
50053
50437
50605
51347
50295
51175
50408
51172
51340
50292
50602
51186
50054
50266
50259
50250
50406
51349
50439
50257
51188
50268
51195
50047
50275
50611
51161
51353
50281
51159
51354
50286
51166
50424
50616
50272
51192
50040
50049
51362
51150
50412
50620
50618
50288
51168
50627
50415
50243
51365
50441
50217
50025
50819
51331
51103
51567
50687
50826
51558
51560
51104
51336
50022
50210
50446
51309
50821
51132
51300
51556
50642
50470
50014
51569
50689
50817
50013
50477
50645
50483
51307
51135
50448
51338
51169
50289
50619
51364
51156
50414
50626
50242
50048
50245
50621
50413
51151
51363
50628
50617
50425
51193
50041
50273
51167
51355
50287
51352
50280
51160
50274
51194
50046
50422
50610
50482
51134
51306
50012
50644
51339
50449
50643
50015
51301
51133
50485
51557
50816
50688
51568
50211
50023
50447
51561
51105
50820
51308
51102
51330
50686
51566
50440
50818
50024
50216
51559
50169
50955
50156
51084
50364
50700
50532
51070
50390
51048
50952
50738
50397
51077
50999
50707
50363
51083
50990
50158
50964
51273
51041
50193
50355
50167
50503
50731
50709
50399
51079
50997
50736
50504
50160
51280
50352
51046
50194
51274
51482
50306
50134
51220
51012
51476
50796
50339
50791
51471
51015
51227
50133
50301
50557
51485
51218
50568
51023
51211
50753
50561
50105
50337
51478
50798
50308
50330
50102
50566
50754
51216
51024
50559
51229
50996
51078
50962
50708
51275
51047
50195
50505
50737
51281
50353
50161
50965
50159
50991
50166
50354
50730
50502
51040
50192
51272
50739
51049
50706
50150
51082
50362
50998
51076
50954
50168
50391
51071
50365
50157
51085
50701
51025
51217
50103
50331
50755
50567
51228
50558
50560
50752
50336
50104
51210
50799
51479
50300
50132
50556
51484
51470
50790
51226
51014
50569
51219
51013
51221
50797
51477
50551
51483
50135
50307
50338
50171
50343
51291
50727
50515
50185
51057
51265
50972
50388
50986
51068
51262
50182
51050
51606
50344
51296
50981
50149
51254
50386
50988
51066
50372
50524
51059
50711
50523
51095
50147
50375
51061
51253
50381
51298
51238
50577
50745
50321
50113
51207
51035
51469
50789
50319
51456
51032
50114
50326
50742
50570
51209
51236
50780
51460
50774
50122
50310
51458
50317
50125
51493
50773
51467
50787
51231
51003
51252
50380
51060
50710
50374
51094
50146
51299
51093
50373
50525
51067
50989
51255
50387
50728
51058
50345
51297
50183
51051
51263
51607
50148
50974
51264
50184
51056
50342
50170
50514
50726
51069
50987
50973
51459
50329
50786
51466
51002
51230
50124
50316
50772
51492
50578
51208
50775
50311
50123
51237
51461
50781
50318
50788
51468
50327
50115
50571
50743
51457
51201
51033
51239
51034
51206
50744
50576
50112
50320
50060
50252
50404
51146
50609
50299
51179
51373
51141
50403
50255
50058
50297
51345
51177
50263
50051
51183
50435
50607
51148
50056
51184
50264
51170
50290
51342
50801
51329
50466
50654
51316
51124
50492
51578
50698
50208
51123
51311
50005
50237
50653
51318
50468
51327
50691
51571
50665
51585
50033
50201
50239
50206
50034
51582
51576
50696
51320
51112
50291
51343
51171
50433
50601
50265
50057
51185
50050
51182
50262
50606
50434
50296
51344
51149
50402
50254
51140
50059
51147
50253
50405
51178
50298
50608
50697
51577
51113
51321
50035
50207
50663
51583
50469
51319
51584
50664
50200
50032
51326
51114
51570
50690
50807
50209
50699
51579
50236
50004
50652
50494
51122
51328
50800
51317
50493
50655
50467
50003
51563
50683
51335
51107
50213
50445
51138
50648
50822
50442
50026
50214
51100
51332
51564
50019
50825
50489
50010
50646
50480
51136
51304
51109
51303
51131
50487
50017
50028
50814
50418
51165
50285
51357
50615
50427
50043
51191
50271
50249
50276
50044
51196
50612
50282
51350
51162
51359
50416
50624
50240
51154
50278
51198
51153
51361
50247
50623
50411
50016
51130
51302
50486
50815
50029
50481
51305
51137
50011
50647
50812
51333
51101
51565
50685
50443
50215
50027
50488
50824
50020
50212
50444
50682
51562
51106
51334
50649
50823
51139
51199
50279
50246
50410
50622
51360
51152
51358
50428
51155
50625
50417
50241
50248
51163
50283
51351
50045
51197
50277
50613
50421
50419
51369
50426
50614
50270
50042
51190
50284
51356
51164
51472
50792
51224
51016
50302
50130
51486
50554
51029
51481
50553
50305
51011
51223
50795
50333
50757
50565
51027
51215
51488
51018
51212
51020
50562
50750
50334
50106
51279
50199
50509
51074
50704
51080
50152
50360
50956
50358
50367
51087
50969
50531
50703
51241
51073
50960
50994
51248
50507
50735
50351
50163
51277
50197
51045
50993
50369
51089
50967
50190
51042
50164
50958
50356
50732
50500
50751
50563
50107
50335
51021
51213
51214
51026
50332
50564
50756
51019
51489
51222
51010
51474
50794
51480
50552
50304
50136
50109
50131
50303
51487
50555
50793
51473
51017
51225
51028
50966
51088
50368
50992
50357
50959
50501
50733
51271
50191
51249
50995
50961
50196
51044
51276
50162
50350
51282
50359
50957
51072
51240
50968
51086
50366
50702
50530
50198
51278
50705
50361
51081
50153
51075


================================================
FILE: imports/ABIDEDataset.py
================================================
import torch
from torch_geometric.data import InMemoryDataset,Data
from os.path import join, isfile
from os import listdir
import numpy as np
import os.path as osp
from imports.read_abide_stats_parall import read_data


class ABIDEDataset(InMemoryDataset):
    """In-memory torch_geometric dataset over preprocessed ABIDE graphs.

    Expects per-subject files under ``<root>/raw`` (produced by the
    preprocessing scripts) and caches the collated result as ``data.pt``
    under the processed directory.
    """

    def __init__(self, root, name, transform=None, pre_transform=None):
        self.root = root
        self.name = name
        super(ABIDEDataset, self).__init__(root, transform, pre_transform)
        # Load the collated (data, slices) pair produced by process().
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # All regular files inside <root>/raw, in sorted (deterministic) order.
        raw_dir = osp.join(self.root, 'raw')
        return sorted(f for f in listdir(raw_dir) if osp.isfile(osp.join(raw_dir, f)))

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        # Raw files are created by the 01/02 preprocessing scripts; no-op here.
        return

    def process(self):
        # Read every raw file into a single collated Data object + slices.
        self.data, self.slices = read_data(self.raw_dir)

        if self.pre_filter is not None:
            kept = [d for d in (self.get(i) for i in range(len(self)))
                    if self.pre_filter(d)]
            self.data, self.slices = self.collate(kept)

        if self.pre_transform is not None:
            transformed = [self.pre_transform(self.get(i))
                           for i in range(len(self))]
            self.data, self.slices = self.collate(transformed)

        torch.save((self.data, self.slices), self.processed_paths[0])

    def __repr__(self):
        return '{}({})'.format(self.name, len(self))


================================================
FILE: imports/__inits__.py
================================================


================================================
FILE: imports/gdc.py
================================================
import torch
import numba
import numpy as np
from scipy.linalg import expm
from torch_geometric.utils import add_self_loops, is_undirected, to_dense_adj
from torch_sparse import coalesce
from torch_scatter import scatter_add


def jit():
    """Return a decorator that applies ``numba.jit`` to a function.

    On-disk caching (cache=True) is attempted first; if numba raises a
    RuntimeError (e.g. the cache directory is not writable), the function is
    compiled without caching instead.
    """
    def decorator(func):
        try:
            compiled = numba.jit(cache=True)(func)
        except RuntimeError:
            compiled = numba.jit(cache=False)(func)
        return compiled

    return decorator


class GDC(object):
    r"""Processes the graph via Graph Diffusion Convolution (GDC) from the
    `"Diffusion Improves Graph Learning" <https://www.kdd.in.tum.de/gdc>`_
    paper.
    .. note::
        The paper offers additional advice on how to choose the
        hyperparameters.
        For an example of using GCN with GDC, see `examples/gcn.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        gcn.py>`_.
    Args:
        self_loop_weight (float, optional): Weight of the added self-loop.
            Set to :obj:`None` to add no self-loops. (default: :obj:`1`)
        normalization_in (str, optional): Normalization of the transition
            matrix on the original (input) graph. Possible values:
            :obj:`"sym"`, :obj:`"col"`, and :obj:`"row"`.
            See :func:`GDC.transition_matrix` for details.
            (default: :obj:`"sym"`)
        normalization_out (str, optional): Normalization of the transition
            matrix on the transformed GDC (output) graph. Possible values:
            :obj:`"sym"`, :obj:`"col"`, :obj:`"row"`, and :obj:`None`.
            See :func:`GDC.transition_matrix` for details.
            (default: :obj:`"col"`)
        diffusion_kwargs (dict, optional): Dictionary containing the parameters
            for diffusion.
            `method` specifies the diffusion method (:obj:`"ppr"`,
            :obj:`"heat"` or :obj:`"coeff"`).
            Each diffusion method requires different additional parameters.
            See :func:`GDC.diffusion_matrix_exact` or
            :func:`GDC.diffusion_matrix_approx` for details.
            (default: :obj:`dict(method='ppr', alpha=0.15)`)
        sparsification_kwargs (dict, optional): Dictionary containing the
            parameters for sparsification.
            `method` specifies the sparsification method (:obj:`"threshold"` or
            :obj:`"topk"`).
            Each sparsification method requires different additional
            parameters.
            See :func:`GDC.sparsify_dense` for details.
            (default: :obj:`dict(method='threshold', avg_degree=64)`)
        exact (bool, optional): Whether to exactly calculate the diffusion
            matrix.
            Note that the exact variants are not scalable.
            They densify the adjacency matrix and calculate either its inverse
            or its matrix exponential.
            However, the approximate variants do not support edge weights and
            currently only personalized PageRank and sparsification by
            threshold are implemented as fast, approximate versions.
            (default: :obj:`True`)
    :rtype: :class:`torch_geometric.data.Data`
    """
    def __init__(self, self_loop_weight=1, normalization_in='sym',
                 normalization_out='col',
                 diffusion_kwargs=dict(method='ppr', alpha=0.15),
                 sparsification_kwargs=dict(method='threshold',
                                            avg_degree=64), exact=True):
        """Store GDC configuration.

        See the class docstring for the meaning of each parameter.
        """
        self.self_loop_weight = self_loop_weight
        self.normalization_in = normalization_in
        self.normalization_out = normalization_out
        # Copy the kwarg dicts: default arguments are created once at function
        # definition time, so storing them directly would share one mutable
        # dict across every GDC instance (mutating it on one instance would
        # silently affect the others).
        self.diffusion_kwargs = dict(diffusion_kwargs)
        self.sparsification_kwargs = dict(sparsification_kwargs)
        self.exact = exact

        if self_loop_weight:
            # Approximate (non-exact) diffusion only supports unit self-loops.
            assert exact or self_loop_weight == 1

    @torch.no_grad()
    def __call__(self, data):
        """Apply Graph Diffusion Convolution to ``data`` in place.

        Pipeline: (optional) self-loops -> coalesce -> transition matrix ->
        diffusion (exact or approximate) -> sparsification -> coalesce ->
        output normalization. Returns the modified ``data`` object.
        """
        N = data.num_nodes
        edge_index = data.edge_index
        if data.edge_attr is None:
            # Unweighted graph: treat every edge as weight 1.
            edge_weight = torch.ones(edge_index.size(1),
                                     device=edge_index.device)
        else:
            edge_weight = data.edge_attr
            # Edge weights are only supported by the exact variants.
            assert self.exact
            assert edge_weight.dim() == 1

        if self.self_loop_weight:
            edge_index, edge_weight = add_self_loops(
                edge_index, edge_weight, fill_value=self.self_loop_weight,
                num_nodes=N)

        # Merge duplicate edges (sums their weights).
        edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)

        if self.exact:
            # Dense path: normalize, diffuse densely, then re-sparsify.
            edge_index, edge_weight = self.transition_matrix(
                edge_index, edge_weight, N, self.normalization_in)
            diff_mat = self.diffusion_matrix_exact(edge_index, edge_weight, N,
                                                   **self.diffusion_kwargs)
            edge_index, edge_weight = self.sparsify_dense(
                diff_mat, **self.sparsification_kwargs)
        else:
            # Sparse path: approximate diffusion stays sparse throughout.
            edge_index, edge_weight = self.diffusion_matrix_approx(
                edge_index, edge_weight, N, self.normalization_in,
                **self.diffusion_kwargs)
            edge_index, edge_weight = self.sparsify_sparse(
                edge_index, edge_weight, N, **self.sparsification_kwargs)

        edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)
        # Final normalization of the sparsified diffusion result.
        edge_index, edge_weight = self.transition_matrix(
            edge_index, edge_weight, N, self.normalization_out)

        data.edge_index = edge_index
        data.edge_attr = edge_weight

        return data

    def transition_matrix(self, edge_index, edge_weight, num_nodes,
                          normalization):
        r"""Calculate the degree-normalized transition matrix of a given
        sparse graph.
        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            normalization (str): Normalization scheme:
                1. :obj:`"sym"`: Symmetric normalization
                   :math:`\mathbf{T} = \mathbf{D}^{-1/2} \mathbf{A}
                   \mathbf{D}^{-1/2}`.
                2. :obj:`"col"`: Column-wise normalization
                   :math:`\mathbf{T} = \mathbf{A} \mathbf{D}^{-1}`.
                3. :obj:`"row"`: Row-wise normalization
                   :math:`\mathbf{T} = \mathbf{D}^{-1} \mathbf{A}`.
                4. :obj:`None`: No normalization.
        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        if normalization == 'sym':
            row, col = edge_index
            # Weighted in-degree per node; zero-degree nodes yield inf below,
            # which is reset to 0 so isolated nodes contribute nothing.
            deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
        elif normalization == 'col':
            _, col = edge_index
            deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
            deg_inv = 1. / deg
            deg_inv[deg_inv == float('inf')] = 0
            edge_weight = edge_weight * deg_inv[col]
        elif normalization == 'row':
            row, _ = edge_index
            deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
            deg_inv = 1. / deg
            deg_inv[deg_inv == float('inf')] = 0
            edge_weight = edge_weight * deg_inv[row]
        elif normalization is None:
            pass
        else:
            raise ValueError(
                'Transition matrix normalization {} unknown.'.format(
                    normalization))

        return edge_index, edge_weight

    def diffusion_matrix_exact(self, edge_index, edge_weight, num_nodes,
                               method, **kwargs):
        r"""Calculate the (dense) diffusion on a given sparse graph.
        Note that these exact variants are not scalable. They densify the
        adjacency matrix and calculate either its inverse or its matrix
        exponential.
        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            method (str): Diffusion method:
                1. :obj:`"ppr"`: Use personalized PageRank as diffusion.
                   Additionally expects the parameter:
                   - **alpha** (*float*) - Return probability in PPR.
                     Commonly lies in :obj:`[0.05, 0.2]`.
                2. :obj:`"heat"`: Use heat kernel diffusion.
                   Additionally expects the parameter:
                   - **t** (*float*) - Time of diffusion. Commonly lies in
                     :obj:`[2, 10]`.
                3. :obj:`"coeff"`: Freely choose diffusion coefficients.
                   Additionally expects the parameter:
                   - **coeffs** (*List[float]*) - List of coefficients
                     :obj:`theta_k` for each power of the transition matrix
                     (starting at :obj:`0`).
        :rtype: (:class:`Tensor`)
        """
        if method == 'ppr':
            # α (I_n + (α - 1) A)^-1
            # Self-loops with fill_value=1 realize the I_n term before the
            # dense inverse is taken.
            edge_weight = (kwargs['alpha'] - 1) * edge_weight
            edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                     fill_value=1,
                                                     num_nodes=num_nodes)
            mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()
            diff_matrix = kwargs['alpha'] * torch.inverse(mat)

        elif method == 'heat':
            # exp(t (A - I_n))
            # Self-loops with fill_value=-1 realize the -I_n term.
            edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                     fill_value=-1,
                                                     num_nodes=num_nodes)
            edge_weight = kwargs['t'] * edge_weight
            mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()
            # Symmetric matrices permit a cheaper/stabler matrix exponential.
            undirected = is_undirected(edge_index, edge_weight, num_nodes)
            diff_matrix = self.__expm__(mat, undirected)

        elif method == 'coeff':
            adj_matrix = to_dense_adj(edge_index,
                                      edge_attr=edge_weight).squeeze()
            mat = torch.eye(num_nodes, device=edge_index.device)

            # Horner-style accumulation of sum_k coeffs[k] * A^k.
            diff_matrix = kwargs['coeffs'][0] * mat
            for coeff in kwargs['coeffs'][1:]:
                mat = mat @ adj_matrix
                diff_matrix += coeff * mat
        else:
            raise ValueError('Exact GDC diffusion {} unknown.'.format(method))

        return diff_matrix

    def diffusion_matrix_approx(self, edge_index, edge_weight, num_nodes,
                                normalization, method, **kwargs):
        r"""Calculate the approximate, sparse diffusion on a given sparse
        graph.

        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            normalization (str): Transition matrix normalization scheme
                (:obj:`"sym"`, :obj:`"row"`, or :obj:`"col"`).
                See :func:`GDC.transition_matrix` for details.
            method (str): Diffusion method:
                1. :obj:`"ppr"`: Use personalized PageRank as diffusion.
                   Additionally expects the parameters:
                   - **alpha** (*float*) - Return probability in PPR.
                     Commonly lies in :obj:`[0.05, 0.2]`.
                   - **eps** (*float*) - Threshold for PPR calculation stopping
                     criterion (:obj:`edge_weight >= eps * out_degree`).
                     Recommended default: :obj:`1e-4`.
        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        if method == 'ppr':
            if normalization == 'sym':
                # Calculate original degrees (needed later to convert the
                # row-normalized PPR result back to symmetric normalization).
                _, col = edge_index
                deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)

            edge_index_np = edge_index.cpu().numpy()
            # Assumes coalesced edge_index: rows sorted so np.unique yields a
            # CSR-style index pointer and per-node out-degree counts.
            _, indptr, out_degree = np.unique(edge_index_np[0],
                                              return_index=True,
                                              return_counts=True)

            neighbors, neighbor_weights = GDC.__calc_ppr__(
                indptr, edge_index_np[1], out_degree, kwargs['alpha'],
                kwargs['eps'])
            ppr_normalization = 'col' if normalization == 'col' else 'row'
            edge_index, edge_weight = self.__neighbors_to_graph__(
                neighbors, neighbor_weights, ppr_normalization,
                device=edge_index.device)
            edge_index = edge_index.to(torch.long)

            if normalization == 'sym':
                # We can change the normalization from row-normalized to
                # symmetric by multiplying the resulting matrix with D^{1/2}
                # from the left and D^{-1/2} from the right.
                # Since we use the original degrees for this it will be like
                # we had used symmetric normalization from the beginning
                # (except for errors due to approximation).
                # NOTE(review): despite its name, `deg_inv` holds D^{1/2}
                # (deg.sqrt()); the math matches the comment above.
                row, col = edge_index
                deg_inv = deg.sqrt()
                deg_inv_sqrt = deg.pow(-0.5)
                deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
                edge_weight = deg_inv[row] * edge_weight * deg_inv_sqrt[col]
            elif normalization in ['col', 'row']:
                pass
            else:
                raise ValueError(
                    ('Transition matrix normalization {} not implemented for '
                     'non-exact GDC computation.').format(normalization))

        elif method == 'heat':
            raise NotImplementedError(
                ('Currently no fast heat kernel is implemented. You are '
                 'welcome to create one yourself, e.g., based on '
                 '"Kloster and Gleich: Heat kernel based community detection '
                 '(KDD 2014)."'))
        else:
            raise ValueError(
                'Approximate GDC diffusion {} unknown.'.format(method))

        return edge_index, edge_weight

    def sparsify_dense(self, matrix, method, **kwargs):
        r"""Sparsifies the given dense matrix into an edge list.

        Args:
            matrix (Tensor): Square dense matrix to sparsify.
            method (str): Method of sparsification. Options:
                1. :obj:`"threshold"`: Remove all edges with weights smaller
                   than :obj:`eps`. Expects either **eps** (*float*) or
                   **avg_degree** (*int*, used to derive :obj:`eps`).
                2. :obj:`"topk"`: Keep edges with top :obj:`k` edge weights
                   per node. Expects **k** (*int*) and **dim** (*int*, the
                   axis along which to take the top :obj:`k`).
        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        assert matrix.shape[0] == matrix.shape[1]
        num_nodes = matrix.shape[1]

        if method == 'threshold':
            # Derive eps from the requested average degree when not given.
            if 'eps' not in kwargs:
                kwargs['eps'] = self.__calculate_eps__(matrix, num_nodes,
                                                       kwargs['avg_degree'])
            edge_index = torch.nonzero(matrix >= kwargs['eps']).t()
            flat_pos = edge_index[0] * num_nodes + edge_index[1]
            edge_weight = matrix.flatten()[flat_pos]

        elif method == 'topk':
            dim = kwargs['dim']
            assert dim in [0, 1]
            k = kwargs['k']
            order = torch.argsort(matrix, dim=dim, descending=True)
            if dim == 0:
                # Top-k sources per target column.
                top_idx = order[:k]
                edge_weight = torch.gather(matrix, dim=dim,
                                           index=top_idx).flatten()
                target_idx = torch.arange(0, num_nodes,
                                          device=matrix.device).repeat(k)
                edge_index = torch.stack([top_idx.flatten(), target_idx],
                                         dim=0)
            else:
                # Top-k targets per source row.
                top_idx = order[:, :k]
                edge_weight = torch.gather(matrix, dim=dim,
                                           index=top_idx).flatten()
                source_idx = torch.arange(
                    0, num_nodes, device=matrix.device).repeat_interleave(k)
                edge_index = torch.stack([source_idx, top_idx.flatten()],
                                         dim=0)
        else:
            raise ValueError('GDC sparsification {} unknown.'.format(method))

        return edge_index, edge_weight

    def sparsify_sparse(self, edge_index, edge_weight, num_nodes, method,
                        **kwargs):
        r"""Sparsifies a given sparse graph further.

        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            method (str): Only :obj:`"threshold"` is implemented: drop all
                edges with weight below :obj:`eps`. Expects either
                **eps** (*float*) or **avg_degree** (*int*, used to derive
                :obj:`eps`).
        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        if method == 'threshold':
            # Derive eps from the requested average degree when not given.
            if 'eps' not in kwargs:
                kwargs['eps'] = self.__calculate_eps__(edge_weight, num_nodes,
                                                       kwargs['avg_degree'])
            keep = torch.nonzero(edge_weight >= kwargs['eps']).flatten()
            return edge_index[:, keep], edge_weight[keep]

        if method == 'topk':
            raise NotImplementedError(
                'Sparse topk sparsification not implemented.')

        raise ValueError('GDC sparsification {} unknown.'.format(method))

    def __expm__(self, matrix, symmetric):
        r"""Calculates matrix exponential.
        Args:
            matrix (Tensor): Matrix to take exponential of.
            symmetric (bool): Specifies whether the matrix is symmetric.
        :rtype: (:class:`Tensor`)
        """
        if symmetric:
            e, V = torch.symeig(matrix, eigenvectors=True)
            diff_mat = V @ torch.diag(e.exp()) @ V.t()
        else:
            diff_mat_np = expm(matrix.cpu().numpy())
            diff_mat = torch.Tensor(diff_mat_np).to(matrix.device)
        return diff_mat

    def __calculate_eps__(self, matrix, num_nodes, avg_degree):
        r"""Calculates threshold necessary to achieve a given average degree.
        Args:
            matrix (Tensor): Adjacency matrix or edge weights.
            num_nodes (int): Number of nodes.
            avg_degree (int): Target average degree.
        :rtype: (:class:`float`)
        """
        sorted_edges = torch.sort(matrix.flatten(), descending=True).values
        if avg_degree * num_nodes > len(sorted_edges):
            return -np.inf
        return sorted_edges[avg_degree * num_nodes - 1]

    def __neighbors_to_graph__(self, neighbors, neighbor_weights,
                               normalization='row', device='cpu'):
        r"""Combine a list of neighbors and neighbor weights to create a sparse
        graph.
        Args:
            neighbors (List[List[int]]): List of neighbors for each node.
            neighbor_weights (List[List[float]]): List of weights for the
                neighbors of each node.
            normalization (str): Normalization of resulting matrix
                (options: :obj:`"row"`, :obj:`"col"`). (default: :obj:`"row"`)
            device (torch.device): Device to create output tensors on.
                (default: :obj:`"cpu"`)
        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        edge_weight = torch.Tensor(np.concatenate(neighbor_weights)).to(device)
        i = np.repeat(np.arange(len(neighbors)),
                      np.fromiter(map(len, neighbors), dtype=np.int))
        j = np.concatenate(neighbors)
        if normalization == 'col':
            edge_index = torch.Tensor(np.vstack([j, i])).to(device)
            N = len(neighbors)
            edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)
        elif normalization == 'row':
            edge_index = torch.Tensor(np.vstack([i, j])).to(device)
        else:
            raise ValueError(
                f"PPR matrix normalization {normalization} unknown.")
        return edge_index, edge_weight

    @staticmethod
    @jit()
    def __calc_ppr__(indptr, indices, out_degree, alpha, eps):
        r"""Calculate the personalized PageRank vector for all nodes
        using a variant of the Andersen algorithm
        (see Andersen et al.: "Local Graph Partitioning using PageRank
        Vectors").

        Args:
            indptr (np.ndarray): Index pointer for the sparse matrix
                (CSR-format).
            indices (np.ndarray): Indices of the sparse matrix entries
                (CSR-format).
            out_degree (np.ndarray): Out-degree of each node.
            alpha (float): Alpha of the PageRank to calculate.
            eps (float): Threshold for PPR calculation stopping criterion
                (:obj:`edge_weight >= eps * out_degree`).
        :rtype: (:class:`List[List[int]]`, :class:`List[List[float]]`)
        """
        alpha_eps = alpha * eps
        js = []
        vals = []
        for inode in range(len(out_degree)):
            # p: approximate PPR vector for `inode`; r: residual mass still
            # to be pushed; q: worklist of nodes with large residuals.
            p = {inode: 0.0}
            r = {}
            r[inode] = alpha
            q = [inode]
            while len(q) > 0:
                unode = q.pop()

                # Move the node's residual into its PPR estimate.
                res = r[unode] if unode in r else 0
                if unode in p:
                    p[unode] += res
                else:
                    p[unode] = res
                r[unode] = 0
                # Push the remaining (1 - alpha) fraction to the neighbors.
                for vnode in indices[indptr[unode]:indptr[unode + 1]]:
                    _val = (1 - alpha) * res / out_degree[unode]
                    if vnode in r:
                        r[vnode] += _val
                    else:
                        r[vnode] = _val

                    # Re-enqueue neighbors whose residual exceeds the
                    # eps-scaled out-degree threshold.
                    res_vnode = r[vnode] if vnode in r else 0
                    if res_vnode >= alpha_eps * out_degree[vnode]:
                        if vnode not in q:
                            q.append(vnode)
            js.append(list(p.keys()))
            vals.append(list(p.values()))
        return js, vals

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)

================================================
FILE: imports/preprocess_data.py
================================================
# Copyright (c) 2019 Mwiza Kunda
# Copyright (C) 2017 Sarah Parisot <s.parisot@imperial.ac.uk>, Sofia Ira Ktena <ira.ktena@imperial.ac.uk>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implcd ied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import os
import warnings
import glob
import csv
import re
import numpy as np
import scipy.io as sio
import sys
from nilearn import connectome
import pandas as pd
from scipy.spatial import distance
from scipy import signal
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import Normalizer
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
warnings.filterwarnings("ignore")

# Input data variables

root_folder = '/home/azureuser/projects/BrainGNN/data/'
data_folder = os.path.join(root_folder, 'ABIDE_pcp/cpac/filt_noglobal')
phenotype = os.path.join(root_folder, 'ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv')


def fetch_filenames(subject_IDs, file_type, atlas, base_dir=None):
    """
        subject_IDs  : list of short subject IDs in string format
        file_type    : must be one of the available file types
                       ('func_preproc' or 'rois_<atlas>')
        atlas        : atlas name used to build the ROI file suffix
        base_dir     : folder to search in; defaults to the module-level
                       `data_folder`
    returns:
        filenames    : list of file names (same length as subject_IDs);
                       'N/A' for subjects whose file could not be found
    """
    if base_dir is None:
        base_dir = data_folder

    filemapping = {'func_preproc': '_func_preproc.nii.gz',
                   'rois_' + atlas: '_rois_' + atlas + '.1D'}
    suffix = filemapping[file_type]

    # Search with absolute glob patterns instead of the original
    # os.chdir()-based lookup: no cwd side effect, no bare `except:`, and
    # a missing data folder yields 'N/A' instead of crashing.
    filenames = []
    for subject_id in subject_IDs:
        # Look directly in base_dir first, then in the subject's sub-folder.
        matches = glob.glob(os.path.join(base_dir, '*' + subject_id + suffix))
        if not matches:
            matches = glob.glob(
                os.path.join(base_dir, subject_id, '*' + subject_id + suffix))
        filenames.append(os.path.basename(matches[0]) if matches else 'N/A')
    return filenames


# Get timeseries arrays for list of subjects
def get_timeseries(subject_list, atlas_name, silence=False):
    """
        subject_list : list of short subject IDs in string format
        atlas_name   : the atlas based on which the timeseries are generated e.g. aal, cc200
        silence      : suppress the per-file progress print when True
    returns:
        time_series  : list of timeseries arrays, each of shape (timepoints x regions)
    """
    suffix = '_rois_' + atlas_name + '.1D'
    timeseries = []
    for subject in subject_list:
        subject_folder = os.path.join(data_folder, subject)
        # Each subject folder is expected to hold exactly one matching file.
        matches = [f for f in os.listdir(subject_folder) if f.endswith(suffix)]
        fl = os.path.join(subject_folder, matches[0])
        if silence != True:
            print("Reading timeseries file %s" % fl)
        timeseries.append(np.loadtxt(fl, skiprows=0))
    return timeseries


#  compute connectivity matrices
def subject_connectivity(timeseries, subjects, atlas_name, kind, iter_no='', seed=1234,
                         n_subjects='', save=True, save_path=data_folder,
                         validation_ext=''):
    """
        timeseries   : timeseries table for subject (timepoints x regions)
        subjects     : subject IDs
        atlas_name   : name of the parcellation atlas used
        kind         : the kind of connectivity to be used, e.g. lasso, partial correlation, correlation
        iter_no      : tangent connectivity iteration number for cross validation evaluation
        seed         : random-seed tag embedded in the saved TE/TPE file name
        n_subjects   : subject-count tag embedded in the saved TE/TPE file name
        save         : save the connectivity matrix to a file
        save_path    : specify path to save the matrix if different from subject folder
        validation_ext : validation-scheme tag for the TE/TPE file name.
                       NOTE(review): this was previously an undefined global
                       (NameError at runtime); default '' preserves the rest
                       of the naming scheme — confirm the intended tag.
    returns:
        connectivity : connectivity matrices (subjects x regions x regions),
                       or the fitted tangent measure for 'TPE'/'TE' saves
    """
    if kind not in ['TPE', 'TE', 'correlation', 'partial correlation']:
        # Previously an unknown kind fell through and crashed later with
        # NameError on `connectivity`; fail fast with a clear message.
        raise ValueError('Unknown connectivity kind: {}'.format(kind))

    if kind not in ['TPE', 'TE']:
        conn_measure = connectome.ConnectivityMeasure(kind=kind)
        connectivity = conn_measure.fit_transform(timeseries)
    elif kind == 'TPE':
        # Tangent-space embedding of correlation matrices.
        conn_measure = connectome.ConnectivityMeasure(kind='correlation')
        conn_mat = conn_measure.fit_transform(timeseries)
        conn_measure = connectome.ConnectivityMeasure(kind='tangent')
        connectivity_fit = conn_measure.fit(conn_mat)
        connectivity = connectivity_fit.transform(conn_mat)
    else:
        # 'TE': tangent-space embedding directly on the timeseries.
        conn_measure = connectome.ConnectivityMeasure(kind='tangent')
        connectivity_fit = conn_measure.fit(timeseries)
        connectivity = connectivity_fit.transform(timeseries)

    if save:
        if kind not in ['TPE', 'TE']:
            for i, subj_id in enumerate(subjects):
                subject_file = os.path.join(save_path, subj_id,
                                            subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '.mat')
                sio.savemat(subject_file, {'connectivity': connectivity[i]})
            return connectivity
        else:
            for i, subj_id in enumerate(subjects):
                subject_file = os.path.join(save_path, subj_id,
                                            subj_id + '_' + atlas_name + '_' + kind.replace(' ', '_') + '_' + str(
                                                iter_no) + '_' + str(seed) + '_' + validation_ext + str(
                                                n_subjects) + '.mat')
                sio.savemat(subject_file, {'connectivity': connectivity[i]})
            return connectivity_fit


# Get the list of subject IDs

def get_ids(num_subjects=None):
    """
        num_subjects : optional cap on how many IDs to return
    return:
        subject_IDs  : array of subject ID strings read from
                       <data_folder>/subject_IDs.txt
    """
    id_file = os.path.join(data_folder, 'subject_IDs.txt')
    subject_IDs = np.genfromtxt(id_file, dtype=str)
    if num_subjects is None:
        return subject_IDs
    return subject_IDs[:num_subjects]


# Get phenotype values for a list of subjects
def get_subject_score(subject_list, score):
    """Read one phenotypic column for the given subjects.

    Returns a dict mapping SUB_ID -> value, with special handling:
    handedness sentinels/-9999 default to 'R' ('Mixed' and 'L->R' map to
    'Ambi'), and missing IQ scores (FIQ/PIQ/VIQ) default to 100.
    """
    scores_dict = {}

    with open(phenotype) as csv_file:
        for row in csv.DictReader(csv_file):
            sub_id = row['SUB_ID']
            if sub_id not in subject_list:
                continue
            value = row[score]
            if score == 'HANDEDNESS_CATEGORY':
                if value.strip() == '-9999' or value.strip() == '':
                    scores_dict[sub_id] = 'R'
                elif value == 'Mixed' or value == 'L->R':
                    scores_dict[sub_id] = 'Ambi'
                else:
                    scores_dict[sub_id] = value
            elif score == 'FIQ' or score == 'PIQ' or score == 'VIQ':
                if value.strip() == '-9999' or value.strip() == '':
                    scores_dict[sub_id] = 100
                else:
                    scores_dict[sub_id] = float(value)
            else:
                scores_dict[sub_id] = value

    return scores_dict


# preprocess phenotypes. Categorical -> ordinal representation
def preprocess_phenotypes(pheno_ft, params):
    """Encode the leading categorical phenotype columns as ordinals.

    MIDA models treat the first 3 columns as categorical; all other models
    treat the first 4. Remaining columns pass through unchanged.
    """
    categorical_cols = [0, 1, 2] if params['model'] == 'MIDA' else [0, 1, 2, 3]
    transformer = ColumnTransformer(
        [("ordinal", OrdinalEncoder(), categorical_cols)],
        remainder='passthrough')
    return transformer.fit_transform(pheno_ft).astype('float32')


# create phenotype feature vector to concatenate with fmri feature vectors
def phenotype_ft_vector(pheno_ft, num_subjects, params):
    """Build per-subject phenotype feature vectors.

    Produces [gender one-hot (2), age, FIQ] plus handedness (3-way one-hot)
    and, for non-MIDA models, eye-status (2-way one-hot), concatenated
    column-wise into one (num_subjects x D) array.
    """
    gender = pheno_ft[:, 0]
    # Column layout differs per model.
    # NOTE(review): for 'MIDA', `eye` reads column 0 — the same column as
    # gender. Preserved as-is; confirm the intended MIDA column layout.
    if params['model'] == 'MIDA':
        eye, hand, age, fiq = (pheno_ft[:, 0], pheno_ft[:, 2],
                               pheno_ft[:, 3], pheno_ft[:, 4])
    else:
        eye, hand, age, fiq = (pheno_ft[:, 2], pheno_ft[:, 3],
                               pheno_ft[:, 4], pheno_ft[:, 5])

    base = np.zeros((num_subjects, 4))
    eye_onehot = np.zeros((num_subjects, 2))
    hand_onehot = np.zeros((num_subjects, 3))

    for i in range(num_subjects):
        base[i, int(gender[i])] = 1  # gender one-hot occupies cols 0-1
        base[i, -2] = age[i]
        base[i, -1] = fiq[i]
        eye_onehot[i, int(eye[i])] = 1
        hand_onehot[i, int(hand[i])] = 1

    if params['model'] == 'MIDA':
        return np.concatenate([base, hand_onehot], axis=1)
    return np.concatenate([base, hand_onehot, eye_onehot], axis=1)


# Load precomputed fMRI connectivity networks
def get_networks(subject_list, kind, iter_no='', seed=1234, n_subjects='', atlas_name="aal",
                 variable='connectivity'):
    """
        subject_list : list of subject IDs
        kind         : the kind of connectivity to be used, e.g. lasso, partial correlation, correlation
        atlas_name   : name of the parcellation atlas used
        variable     : variable name in the .mat file that has been used to save the precomputed networks
    return:
        networks     : stacked connectivity matrices
                       (num_subjects x regions x regions)
    """
    # Two-word kinds ('partial correlation') are stored with underscores.
    if len(kind.split()) == 2:
        kind = '_'.join(kind.split())

    all_networks = []
    for subject in subject_list:
        fl = os.path.join(data_folder, subject,
                          subject + "_" + atlas_name + "_" + kind.replace(' ', '_') + ".mat")
        all_networks.append(sio.loadmat(fl)[variable])

    if kind in ['TE', 'TPE']:
        # Tangent-space matrices are used as-is.
        norm_networks = list(all_networks)
    else:
        # Fisher z-transform for correlation-type matrices.
        norm_networks = [np.arctanh(mat) for mat in all_networks]

    return np.stack(norm_networks)



================================================
FILE: imports/read_abide_stats_parall.py
================================================
'''
Author: Xiaoxiao Li
Date: 2019/02/24
'''

import os.path as osp
from os import listdir
import os
import glob
import h5py

import torch
import numpy as np
from scipy.io import loadmat
from torch_geometric.data import Data
import networkx as nx
from networkx.convert_matrix import from_numpy_matrix
import multiprocessing
from torch_sparse import coalesce
from torch_geometric.utils import remove_self_loops
from functools import partial
import deepdish as dd
from imports.gdc import GDC


def split(data, batch):
    """Compute per-graph slice boundaries for a batched graph and shift
    edge indices so that every graph's node ids start at zero.

    Args:
        data: object with edge_index / x / edge_attr / y / pos attributes
            (mutated in place: edge_index is re-based per graph).
        batch (LongTensor): graph id of each node.

    Returns:
        (data, slices): the modified data plus a dict of slice tensors.
    """
    # Cumulative node counts per graph -> [0, n0, n0+n1, ...].
    node_slice = torch.cat(
        [torch.tensor([0]), torch.bincount(batch).cumsum(0)])

    row, _ = data.edge_index
    # An edge belongs to the graph of its source node.
    edge_slice = torch.cat(
        [torch.tensor([0]), torch.bincount(batch[row]).cumsum(0)])

    # Re-base edge indices so each graph's nodes start at zero.
    data.edge_index -= node_slice[batch[row]].unsqueeze(0)

    slices = {'edge_index': edge_slice}
    if data.x is not None:
        slices['x'] = node_slice
    if data.edge_attr is not None:
        slices['edge_attr'] = edge_slice
    if data.y is not None:
        if data.y.size(0) == batch.size(0):
            # Node-level targets slice like nodes.
            slices['y'] = node_slice
        else:
            # Graph-level targets: one entry per graph.
            slices['y'] = torch.arange(0, batch[-1] + 2, dtype=torch.long)
    if data.pos is not None:
        slices['pos'] = node_slice

    return data, slices


def cat(seq):
    """Concatenate tensors along the last dim, ignoring Nones.

    1-D tensors are treated as columns; returns None for an all-None or
    empty sequence. The result is squeezed.
    """
    items = [t for t in seq if t is not None]
    items = [t.unsqueeze(-1) if t.dim() == 1 else t for t in items]
    if len(items) == 0:
        return None
    return torch.cat(items, dim=-1).squeeze()

class NoDaemonProcess(multiprocessing.Process):
    """A Process that always reports itself as non-daemonic.

    Daemonic processes may not spawn children; forcing ``daemon`` to
    ``False`` lets pool workers create their own sub-processes.
    """

    @property
    def daemon(self):
        # Always non-daemonic, regardless of what the pool sets.
        return False

    @daemon.setter
    def daemon(self, value):
        # Silently ignore attempts to mark this process as daemonic.
        pass


# A multiprocessing context whose Process attribute is the non-daemonic
# variant above, so pool workers themselves may spawn child processes.
class NoDaemonContext(type(multiprocessing.get_context())):
    Process = NoDaemonProcess


def read_data(data_dir):
    """Load every per-subject file in `data_dir` (in parallel) and merge
    them into one concatenated PyG `Data` object plus a slice dict.

    Args:
        data_dir (str): directory whose files are each readable by
            `read_sigle_data`.

    Returns:
        (data, slices): concatenated graph data and the per-graph slice
        indices produced by `split`.
    """
    onlyfiles = [f for f in listdir(data_dir) if osp.isfile(osp.join(data_dir, f))]
    onlyfiles.sort()
    batch = []
    pseudo = []
    y_list = []
    edge_att_list, edge_index_list,att_list = [], [], []

    # parallel computing
    cores = multiprocessing.cpu_count()
    pool = multiprocessing.Pool(processes=cores)
    #pool =  MyPool(processes = cores)
    func = partial(read_sigle_data, data_dir)

    import timeit

    start = timeit.default_timer()

    res = pool.map(func, onlyfiles)

    pool.close()
    pool.join()

    stop = timeit.default_timer()

    print('Time: ', stop - start)



    for j in range(len(res)):
        # res[j] = (edge_attr, edge_index, node_feats, label, num_nodes).
        # Offset edge indices by j * num_nodes so graphs do not collide.
        # NOTE(review): this assumes every subject has the same node count.
        edge_att_list.append(res[j][0])
        edge_index_list.append(res[j][1]+j*res[j][4])
        att_list.append(res[j][2])
        y_list.append(res[j][3])
        batch.append([j]*res[j][4])
        # `pos` is an identity matrix: a one-hot encoding per node.
        pseudo.append(np.diag(np.ones(res[j][4])))

    edge_att_arr = np.concatenate(edge_att_list)
    edge_index_arr = np.concatenate(edge_index_list, axis=1)
    att_arr = np.concatenate(att_list, axis=0)
    pseudo_arr = np.concatenate(pseudo, axis=0)
    y_arr = np.stack(y_list)
    edge_att_torch = torch.from_numpy(edge_att_arr.reshape(len(edge_att_arr), 1)).float()
    att_torch = torch.from_numpy(att_arr).float()
    y_torch = torch.from_numpy(y_arr).long()  # classification
    batch_torch = torch.from_numpy(np.hstack(batch)).long()
    edge_index_torch = torch.from_numpy(edge_index_arr).long()
    pseudo_torch = torch.from_numpy(pseudo_arr).float()
    data = Data(x=att_torch, edge_index=edge_index_torch, y=y_torch, edge_attr=edge_att_torch, pos = pseudo_torch )


    data, slices = split(data, batch_torch)

    return data, slices


def read_sigle_data(data_dir,filename,use_gdc =False):
    """Load one subject's graph file and return its edge/node arrays.

    Args:
        data_dir (str): directory containing the data files.
        filename (str): file to load; expected to provide the keys
            'pcorr', 'corr', and 'label'.
        use_gdc (bool): if True, rebuild the graph with Graph Diffusion
            Convolution (PPR diffusion + top-k sparsification).

    Returns:
        (edge_attr, edge_index, node_features, label, num_nodes).
    """

    temp = dd.io.load(osp.join(data_dir, filename))

    # read edge and edge attribute
    # Absolute partial correlations define the edge weights.
    pcorr = np.abs(temp['pcorr'][()])

    num_nodes = pcorr.shape[0]
    # NOTE(review): `from_numpy_matrix` / `to_scipy_sparse_matrix` were
    # removed in NetworkX 3.0; this code requires networkx < 3.
    G = from_numpy_matrix(pcorr)
    A = nx.to_scipy_sparse_matrix(G)
    adj = A.tocoo()
    edge_att = np.zeros(len(adj.row))
    for i in range(len(adj.row)):
        edge_att[i] = pcorr[adj.row[i], adj.col[i]]

    edge_index = np.stack([adj.row, adj.col])
    edge_index, edge_att = remove_self_loops(torch.from_numpy(edge_index), torch.from_numpy(edge_att))
    edge_index = edge_index.long()
    edge_index, edge_att = coalesce(edge_index, edge_att, num_nodes,
                                    num_nodes)
    # Node features come from the 'corr' matrix (one row per node).
    att = temp['corr'][()]
    label = temp['label'][()]

    att_torch = torch.from_numpy(att).float()
    y_torch = torch.from_numpy(np.array(label)).long()  # classification

    data = Data(x=att_torch, edge_index=edge_index.long(), y=y_torch, edge_attr=edge_att)

    if use_gdc:
        '''
        Implementation of https://papers.nips.cc/paper/2019/hash/23c894276a2c5a16470e6a31f4618d73-Abstract.html
        '''
        data.edge_attr = data.edge_attr.squeeze()
        gdc = GDC(self_loop_weight=1, normalization_in='sym',
                  normalization_out='col',
                  diffusion_kwargs=dict(method='ppr', alpha=0.2),
                  sparsification_kwargs=dict(method='topk', k=20,
                                             dim=0), exact=True)
        data = gdc(data)
        return data.edge_attr.data.numpy(),data.edge_index.data.numpy(),data.x.data.numpy(),data.y.data.item(),num_nodes

    else:
        return edge_att.data.numpy(),edge_index.data.numpy(),att,label,num_nodes

if __name__ == "__main__":
    # Smoke test: parse a single subject file from the raw data folder.
    data_dir = '/home/azureuser/projects/BrainGNN/data/ABIDE_pcp/cpac/filt_noglobal/raw'
    filename = '50346.h5'
    read_sigle_data(data_dir, filename)








================================================
FILE: imports/utils.py
================================================
from scipy import stats
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.io import loadmat
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import KFold


def train_val_test_split(kfold=5, fold=0, n_sub=1035):
    """Produce train/validation/test indices for one cross-validation fold.

    :param kfold: (int) number of outer cross-validation folds
    :param fold: (int) which outer fold to return (0-based)
    :param n_sub: (int) total number of subjects; the default 1035 matches
        the cohort originally hard-coded here, so existing callers are
        unaffected
    :return: (train_id, val_id, test_id) numpy index arrays
    """
    id = list(range(n_sub))

    import random
    random.seed(123)
    # NOTE(review): this shuffle does not influence the splits below —
    # KFold only uses the array's length — but it is kept so the global
    # `random` state is seeded/advanced exactly as before.
    random.shuffle(id)

    # Outer split: test vs. rest; inner split: carve the validation set out
    # of the outer training portion (first inner fold only, deterministic
    # thanks to the fixed random_state values).
    kf = KFold(n_splits=kfold, random_state=123, shuffle=True)
    kf2 = KFold(n_splits=kfold - 1, shuffle=True, random_state=666)

    test_index = list()
    train_index = list()
    val_index = list()

    for tr, te in kf.split(np.array(id)):
        test_index.append(te)
        tr_id, val_id = list(kf2.split(tr))[0]
        train_index.append(tr[tr_id])
        val_index.append(tr[val_id])

    train_id = train_index[fold]
    test_id = test_index[fold]
    val_id = val_index[fold]

    return train_id, val_id, test_id

================================================
FILE: net/braingnn.py
================================================
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.utils import (add_self_loops, sort_edge_index,
                                   remove_self_loops)
from torch_sparse import spspmm

from net.braingraphconv import MyNNConv


##########################################################################################################################
class Network(torch.nn.Module):
    """BrainGNN classifier: two ROI-aware graph-conv + TopK-pooling stages
    whose pooled readouts are concatenated and fed to a 3-layer MLP."""

    def __init__(self, indim, ratio, nclass, k=8, R=200):
        '''

        :param indim: (int) node feature dimension
        :param ratio: (float) pooling ratio in (0,1)
        :param nclass: (int)  number of classes
        :param k: (int) number of communities
        :param R: (int) number of ROIs
        '''
        super(Network, self).__init__()

        self.indim = indim
        self.dim1 = 32
        self.dim2 = 32
        self.dim3 = 512
        self.dim4 = 256
        self.dim5 = 8
        self.k = k
        self.R = R

        # n1/n2 map an R-dim positional encoding to a flattened conv kernel
        # through a k-dim (community) bottleneck; MyNNConv reshapes the
        # output into an (indim x dim) weight per node.
        self.n1 = nn.Sequential(nn.Linear(self.R, self.k, bias=False), nn.ReLU(), nn.Linear(self.k, self.dim1 * self.indim))
        self.conv1 = MyNNConv(self.indim, self.dim1, self.n1, normalize=False)
        self.pool1 = TopKPooling(self.dim1, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid)
        self.n2 = nn.Sequential(nn.Linear(self.R, self.k, bias=False), nn.ReLU(), nn.Linear(self.k, self.dim2 * self.dim1))
        self.conv2 = MyNNConv(self.dim1, self.dim2, self.n2, normalize=False)
        self.pool2 = TopKPooling(self.dim2, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid)

        #self.fc1 = torch.nn.Linear((self.dim2) * 2, self.dim2)
        # MLP head over the concatenated [max||mean] readouts of both stages.
        self.fc1 = torch.nn.Linear((self.dim1+self.dim2)*2, self.dim2)
        self.bn1 = torch.nn.BatchNorm1d(self.dim2)
        self.fc2 = torch.nn.Linear(self.dim2, self.dim3)
        self.bn2 = torch.nn.BatchNorm1d(self.dim3)
        self.fc3 = torch.nn.Linear(self.dim3, nclass)




    def forward(self, x, edge_index, batch, edge_attr, pos):
        # Stage 1: position-conditioned convolution, then keep only the
        # top-`ratio` scoring nodes.
        x = self.conv1(x, edge_index, edge_attr, pos)
        x, edge_index, edge_attr, batch, perm, score1 = self.pool1(x, edge_index, edge_attr, batch)

        # Keep positional encodings of the surviving nodes only.
        pos = pos[perm]
        # Graph readout: max-pool || mean-pool per graph.
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        edge_attr = edge_attr.squeeze()
        # Densify connectivity before stage 2 (see augment_adj).
        edge_index, edge_attr = self.augment_adj(edge_index, edge_attr, x.size(0))

        x = self.conv2(x, edge_index, edge_attr, pos)
        x, edge_index, edge_attr, batch, perm, score2 = self.pool2(x, edge_index,edge_attr, batch)

        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # Classify from the concatenation of both stages' readouts.
        x = torch.cat([x1,x2], dim=1)
        x = self.bn1(F.relu(self.fc1(x)))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.bn2(F.relu(self.fc2(x)))
        x= F.dropout(x, p=0.5, training=self.training)
        x = F.log_softmax(self.fc3(x), dim=-1)

        # Returns log-probabilities plus both pooling layers' weight vectors
        # and per-node selection scores (squashed to (0,1)).
        return x,self.pool1.weight,self.pool2.weight, torch.sigmoid(score1).view(x.size(0),-1), torch.sigmoid(score2).view(x.size(0),-1)

    def augment_adj(self, edge_index, edge_weight, num_nodes):
        # Square the self-looped adjacency ((A + I)^2 via sparse matmul) and
        # drop the resulting self-loops: connects 2-hop neighborhoods.
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight



================================================
FILE: net/braingraphconv.py
================================================
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from net.brainmsgpassing import MyMessagePassing
from torch_geometric.utils import add_remaining_self_loops,softmax

from torch_geometric.typing import (OptTensor)

from net.inits import uniform


class MyNNConv(MyMessagePassing):
    """Edge-conditioned graph convolution with per-node generated weights.

    Each node's features are transformed by a weight matrix produced from
    its positional encoding ``pseudo`` (via the module ``nn``); messages are
    then aggregated over neighbours with edge weights softmax-normalized per
    target node.
    """

    def __init__(self, in_channels, out_channels, nn, normalize=False, bias=True,
                 **kwargs):
        """
        Args:
            in_channels: size of each input node feature vector.
            out_channels: size of each output node feature vector.
            nn: module mapping ``pseudo`` to a flattened
                (in_channels * out_channels) weight vector per node.
            normalize: if True, L2-normalize updated node embeddings.
            bias: if True, add a learnable per-channel output bias.
        """
        super(MyNNConv, self).__init__(aggr='mean', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.nn = nn

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init in [-1/sqrt(in_channels), 1/sqrt(in_channels)];
        # no-op when bias is disabled.
        uniform(self.in_channels, self.bias)

    def forward(self, x, edge_index, edge_weight=None, pseudo=None, size=None):
        """Compute updated node embeddings.

        Args:
            x: node features, a tensor or a (source, target) pair for
                bipartite graphs.
            edge_index: COO edge indices.
            edge_weight: optional per-edge weights; a trailing singleton
                dimension, if any, is squeezed away.
            pseudo: per-node positional encoding fed to ``self.nn``.
            size: optional (N, M) sizes of a bipartite assignment matrix.
        """
        # BUGFIX: the signature advertises edge_weight as optional, but the
        # previous code unconditionally called edge_weight.squeeze() and
        # crashed with AttributeError when it was left at None.
        if edge_weight is not None:
            edge_weight = edge_weight.squeeze()
        if size is None and torch.is_tensor(x):
            # Self-loops get weight 1 so every node attends to itself.
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, 1, x.size(0))

        # Per-node weight matrices generated from the positional encoding.
        weight = self.nn(pseudo).view(-1, self.in_channels, self.out_channels)
        if torch.is_tensor(x):
            x = torch.matmul(x.unsqueeze(1), weight).squeeze(1)
        else:
            x = (None if x[0] is None else torch.matmul(x[0].unsqueeze(1), weight).squeeze(1),
                 None if x[1] is None else torch.matmul(x[1].unsqueeze(1), weight).squeeze(1))

        return self.propagate(edge_index, size=size, x=x,
                              edge_weight=edge_weight)

    def message(self, edge_index_i, size_i, x_j, edge_weight, ptr: OptTensor):
        # BUGFIX: guard the None case before softmax — the return line below
        # explicitly supports edge_weight is None, but softmax(None, ...)
        # would have raised first.
        if edge_weight is not None:
            # Normalize weights over each target node's incoming edges.
            edge_weight = softmax(edge_weight, edge_index_i, ptr, size_i)
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        if self.normalize:
            aggr_out = F.normalize(aggr_out, p=2, dim=-1)
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)



================================================
FILE: net/brainmsgpassing.py
================================================
import sys
import inspect

import torch
# from torch_geometric.utils import scatter_
from torch_scatter import scatter,scatter_add

# Argument names in message() that propagate() fills in itself rather than
# looking up in **kwargs.
special_args = [
    'edge_index', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j'
]
__size_error_msg__ = ('All tensors which should get mapped to the same source '
                      'or target nodes must be of same size in dimension 0.')

is_python2 = sys.version_info[0] < 3
getargspec = inspect.getargspec if is_python2 else inspect.getfullargspec


class MyMessagePassing(torch.nn.Module):
    r"""Base class for creating message passing layers
    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
        \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
        \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right),
    where :math:`\square` denotes a differentiable, permutation invariant
    function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
    and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
    MLPs.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
    create_gnn.html>`__ for the accompanying tutorial.
    Args:
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`).
            (default: :obj:`"add"`)
        flow (string, optional): The flow direction of message passing
            (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
            (default: :obj:`"source_to_target"`)
        node_dim (int, optional): The axis along which to propagate.
            (default: :obj:`0`)
    """
    def __init__(self, aggr='add', flow='source_to_target', node_dim=0):
        super(MyMessagePassing, self).__init__()

        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max']

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim
        assert self.node_dim >= 0

        # Inspect the subclass's message()/update() signatures once so that
        # propagate() can assemble their arguments dynamically by name.
        self.__message_args__ = getargspec(self.message)[0][1:]
        self.__special_args__ = [(i, arg)
                                 for i, arg in enumerate(self.__message_args__)
                                 if arg in special_args]
        self.__message_args__ = [
            arg for arg in self.__message_args__ if arg not in special_args
        ]
        self.__update_args__ = getargspec(self.update)[0][2:]

    def propagate(self, edge_index, size=None, **kwargs):
        r"""The initial call to start propagating messages.
        Args:
            edge_index (Tensor): The indices of a general (sparse) assignment
                matrix with shape :obj:`[N, M]` (can be directed or
                undirected).
            size (list or tuple, optional): The size :obj:`[N, M]` of the
                assignment matrix. If set to :obj:`None`, the size is tried to
                get automatically inferred and assumed to be symmetric.
                (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct messages
                and to update node embeddings.
        """

        dim = self.node_dim
        size = [None, None] if size is None else list(size)
        assert len(size) == 2

        # i indexes target nodes, j indexes source nodes for the chosen flow.
        i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
        ij = {"_i": i, "_j": j}

        # Gather message() arguments by name, mapping *_i / *_j suffixed
        # tensors onto the corresponding endpoints of each edge.
        message_args = []
        for arg in self.__message_args__:
            if arg[-2:] in ij.keys():
                tmp = kwargs.get(arg[:-2], None)
                if tmp is None:  # pragma: no cover
                    message_args.append(tmp)
                else:
                    idx = ij[arg[-2:]]
                    if isinstance(tmp, tuple) or isinstance(tmp, list):
                        assert len(tmp) == 2
                        if tmp[1 - idx] is not None:
                            if size[1 - idx] is None:
                                size[1 - idx] = tmp[1 - idx].size(dim)
                            if size[1 - idx] != tmp[1 - idx].size(dim):
                                raise ValueError(__size_error_msg__)
                        tmp = tmp[idx]

                    if tmp is None:
                        message_args.append(tmp)
                    else:
                        if size[idx] is None:
                            size[idx] = tmp.size(dim)
                        if size[idx] != tmp.size(dim):
                            raise ValueError(__size_error_msg__)

                        # Expand node-level tensor to edge level.
                        tmp = torch.index_select(tmp, dim, edge_index[idx])
                        message_args.append(tmp)
            else:
                message_args.append(kwargs.get(arg, None))

        size[0] = size[1] if size[0] is None else size[0]
        size[1] = size[0] if size[1] is None else size[1]

        kwargs['edge_index'] = edge_index
        kwargs['size'] = size

        for (idx, arg) in self.__special_args__:
            if arg[-2:] in ij.keys():
                message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]])
            else:
                message_args.insert(idx, kwargs[arg])

        update_args = [kwargs[arg] for arg in self.__update_args__]

        out = self.message(*message_args)
        # FIX: honor self.aggr. The previous code hard-coded scatter_add,
        # silently ignoring aggr='mean'/'max' even though __init__ asserts
        # they are valid choices (and MyNNConv requests aggr='mean'). The
        # commented-out original scatter_ call also used self.aggr.
        out = scatter(out, edge_index[i], dim, dim_size=size[i],
                      reduce=self.aggr)
        out = self.update(out, *update_args)

        return out

    def message(self, x_j):  # pragma: no cover
        r"""Constructs messages to node :math:`i` in analogy to
        :math:`\phi_{\mathbf{\Theta}}` for each edge in
        :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and
        :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`.
        Can take any argument which was initially passed to :meth:`propagate`.
        In addition, tensors passed to :meth:`propagate` can be mapped to the
        respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
        :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`.
        """

        return x_j

    def update(self, aggr_out):  # pragma: no cover
        r"""Updates node embeddings in analogy to
        :math:`\gamma_{\mathbf{\Theta}}` for each node
        :math:`i \in \mathcal{V}`.
        Takes in the output of aggregation as first argument and any argument
        which was initially passed to :meth:`propagate`."""

        return aggr_out


================================================
FILE: net/inits.py
================================================
import math


def uniform(size, tensor):
    """In-place fill of ``tensor`` from U(-1/sqrt(size), 1/sqrt(size)).

    Silently does nothing when ``tensor`` is None (e.g. a disabled bias).
    """
    limit = 1.0 / math.sqrt(size)
    if tensor is not None:
        tensor.data.uniform_(-limit, limit)


def kaiming_uniform(tensor, fan, a):
    """He/Kaiming-style in-place uniform init with negative-slope ``a``.

    Fills from U(-b, b) with b = sqrt(6 / ((1 + a^2) * fan)); no-op on None.
    """
    if tensor is None:
        return
    limit = math.sqrt(6 / ((1 + a ** 2) * fan))
    tensor.data.uniform_(-limit, limit)


def glorot(tensor):
    """Xavier/Glorot in-place uniform init based on the last two dims.

    Fills from U(-s, s) with s = sqrt(6 / (fan_in + fan_out)); no-op on None.
    """
    if tensor is None:
        return
    fan_sum = tensor.size(-2) + tensor.size(-1)
    stdv = math.sqrt(6.0 / fan_sum)
    tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    """Zero-fill ``tensor`` in place; silently ignores None."""
    if tensor is None:
        return
    tensor.data.fill_(0)


def ones(tensor):
    """One-fill ``tensor`` in place; silently ignores None."""
    if tensor is None:
        return
    tensor.data.fill_(1)

================================================
FILE: requirements.txt
================================================
alabaster @ file:///home/ktietz/src/ci/alabaster_1611921544520/work
anaconda-client==1.7.2
anaconda-project @ file:///tmp/build/80754af9/anaconda-project_1610472525955/work
anyio @ file:///tmp/build/80754af9/anyio_1617783275907/work/dist
appdirs==1.4.4
argh==0.26.2
argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613037097816/work
arrow==0.13.1
ase==3.21.1
asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work
astroid @ file:///tmp/build/80754af9/astroid_1613500854201/work
astropy @ file:///tmp/build/80754af9/astropy_1617745353437/work
async-generator @ file:///home/ktietz/src/ci/async_generator_1611927993394/work
atomicwrites==1.4.0
attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work
autopep8 @ file:///tmp/build/80754af9/autopep8_1615918855173/work
Babel @ file:///tmp/build/80754af9/babel_1607110387436/work
backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
backports.shutil-get-terminal-size @ file:///tmp/build/80754af9/backports.shutil_get_terminal_size_1608222128777/work
beautifulsoup4 @ file:///home/linux1/recipes/ci/beautifulsoup4_1610988766420/work
binaryornot @ file:///tmp/build/80754af9/binaryornot_1617751525010/work
bitarray @ file:///tmp/build/80754af9/bitarray_1618431750766/work
bkcharts==0.2
black==19.10b0
bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work
bokeh @ file:///tmp/build/80754af9/bokeh_1617824541184/work
boto==2.49.0
Bottleneck==1.3.2
brotlipy==0.7.0
certifi==2020.12.5
cffi @ file:///tmp/build/80754af9/cffi_1613246945912/work
chardet @ file:///tmp/build/80754af9/chardet_1607706746162/work
click @ file:///home/linux1/recipes/ci/click_1610990599742/work
cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work
clyent==1.2.2
colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work
contextlib2==0.6.0.post1
cookiecutter @ file:///tmp/build/80754af9/cookiecutter_1617748928239/work
cryptography @ file:///tmp/build/80754af9/cryptography_1616769286105/work
cycler==0.10.0
Cython @ file:///tmp/build/80754af9/cython_1618435160151/work
cytoolz==0.11.0
dask @ file:///tmp/build/80754af9/dask-core_1617390489108/work
decorator @ file:///tmp/build/80754af9/decorator_1617916966915/work
deepdish==0.3.6
defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work
distributed @ file:///tmp/build/80754af9/distributed_1617381497899/work
docutils @ file:///tmp/build/80754af9/docutils_1617624660125/work
entrypoints==0.3
et-xmlfile==1.0.1
fastcache==1.1.0
filelock @ file:///home/linux1/recipes/ci/filelock_1610993975404/work
flake8 @ file:///tmp/build/80754af9/flake8_1615834841867/work
Flask @ file:///home/ktietz/src/ci/flask_1611932660458/work
fsspec @ file:///tmp/build/80754af9/fsspec_1617959894824/work
future==0.18.2
gevent @ file:///tmp/build/80754af9/gevent_1616770671827/work
glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work
gmpy2==2.0.8
googledrivedownloader==0.4
greenlet @ file:///tmp/build/80754af9/greenlet_1611957705398/work
h5py==2.10.0
HeapDict==1.0.1
html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work
idna @ file:///home/linux1/recipes/ci/idna_1610986105248/work
imageio @ file:///tmp/build/80754af9/imageio_1617700267927/work
imagesize @ file:///home/ktietz/src/ci/imagesize_1611921604382/work
importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617874469820/work
inflection==0.5.1
iniconfig @ file:///home/linux1/recipes/ci/iniconfig_1610983019677/work
intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work
ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl
ipython @ file:///tmp/build/80754af9/ipython_1617120885885/work
ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work
isodate==0.6.0
isort @ file:///tmp/build/80754af9/isort_1616355431277/work
itsdangerous @ file:///home/ktietz/src/ci/itsdangerous_1611932585308/work
jdcal==1.4.1
jedi @ file:///tmp/build/80754af9/jedi_1606932564285/work
jeepney @ file:///tmp/build/80754af9/jeepney_1606148855031/work
Jinja2 @ file:///tmp/build/80754af9/jinja2_1612213139570/work
jinja2-time @ file:///tmp/build/80754af9/jinja2-time_1617751524098/work
joblib @ file:///tmp/build/80754af9/joblib_1613502643832/work
json5==0.9.5
jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
jupyter==1.0.0
jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work
jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work
jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213311222/work
jupyter-packaging @ file:///tmp/build/80754af9/jupyter-packaging_1613502826984/work
jupyter-server @ file:///tmp/build/80754af9/jupyter_server_1616083640759/work
jupyterlab @ file:///tmp/build/80754af9/jupyterlab_1619133235951/work
jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1617134334258/work
jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work
keyring @ file:///tmp/build/80754af9/keyring_1614616740399/work
kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612282420641/work
lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1616526917483/work
libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work
llvmlite==0.36.0
locket==0.2.1
lxml @ file:///tmp/build/80754af9/lxml_1616443220220/work
MarkupSafe==1.1.1
matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1613407855456/work
mccabe==0.6.1
mistune==0.8.4
mkl-fft==1.3.0
mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853849286/work
mkl-service==2.3.0
mock @ file:///tmp/build/80754af9/mock_1607622725907/work
more-itertools @ file:///tmp/build/80754af9/more-itertools_1613676688952/work
mpmath==1.2.1
msgpack @ file:///tmp/build/80754af9/msgpack-python_1612287151062/work
multipledispatch==0.6.0
mypy-extensions==0.4.3
nbclassic @ file:///tmp/build/80754af9/nbclassic_1616085367084/work
nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914830498/work
nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work
nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work
nibabel==3.2.1
nilearn==0.7.1
nltk @ file:///tmp/build/80754af9/nltk_1618327084230/work
nose @ file:///tmp/build/80754af9/nose_1606773131901/work
notebook @ file:///tmp/build/80754af9/notebook_1616443462982/work
numba @ file:///tmp/build/80754af9/numba_1616774046117/work
numexpr @ file:///tmp/build/80754af9/numexpr_1618856167419/work
numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1618497241363/work
numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work
olefile==0.46
openpyxl @ file:///tmp/build/80754af9/openpyxl_1615411699337/work
packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work
pandas==1.2.4
pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work
parso==0.7.0
partd @ file:///tmp/build/80754af9/partd_1618000087440/work
path @ file:///tmp/build/80754af9/path_1614022220526/work
pathlib2 @ file:///tmp/build/80754af9/pathlib2_1607024983162/work
pathspec==0.7.0
pathtools==0.1.2
patsy==0.5.1
pep8==1.7.1
pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
Pillow @ file:///tmp/build/80754af9/pillow_1617383569452/work
pkginfo==1.7.0
pluggy @ file:///tmp/build/80754af9/pluggy_1615976321666/work
ply==3.11
poyo @ file:///tmp/build/80754af9/poyo_1617751526755/work
prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1618088486455/work
prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work
protobuf==3.17.0
psutil @ file:///tmp/build/80754af9/psutil_1612298023621/work
ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
py @ file:///tmp/build/80754af9/py_1607971587848/work
pycodestyle @ file:///home/ktietz/src/ci_mi/pycodestyle_1612807597675/work
pycosat==0.6.3
pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
pycurl==7.43.0.6
pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1616182067796/work
pyerfa @ file:///tmp/build/80754af9/pyerfa_1619390903914/work
pyflakes @ file:///home/ktietz/src/ci_ipy2/pyflakes_1612551159640/work
Pygments @ file:///tmp/build/80754af9/pygments_1615143339740/work
pylint @ file:///tmp/build/80754af9/pylint_1617135829881/work
pyls-black @ file:///tmp/build/80754af9/pyls-black_1607553132291/work
pyls-spyder @ file:///tmp/build/80754af9/pyls-spyder_1613849700860/work
pyodbc===4.0.0-unsupported
pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1608057966937/work
pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work
PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
pytest==6.2.3
python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work
python-language-server @ file:///tmp/build/80754af9/python-language-server_1607972495879/work
python-louvain==0.15
python-slugify @ file:///tmp/build/80754af9/python-slugify_1620405669636/work
pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658317819/work
pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work
PyYAML==5.4.1
pyzmq==20.0.0
QDarkStyle @ file:///tmp/build/80754af9/qdarkstyle_1617386714626/work
qstylizer @ file:///tmp/build/80754af9/qstylizer_1617713584600/work/dist/qstylizer-0.1.10-py2.py3-none-any.whl
QtAwesome @ file:///tmp/build/80754af9/qtawesome_1615991616277/work
qtconsole @ file:///tmp/build/80754af9/qtconsole_1616775094278/work
QtPy==1.9.0
rdflib==5.0.0
regex @ file:///tmp/build/80754af9/regex_1617569202463/work
requests @ file:///tmp/build/80754af9/requests_1608241421344/work
rope @ file:///tmp/build/80754af9/rope_1602264064449/work
Rtree @ file:///tmp/build/80754af9/rtree_1618420845272/work
ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016699510/work
scikit-image==0.16.2
scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1614446682169/work
scipy @ file:///tmp/build/80754af9/scipy_1618855647378/work
seaborn @ file:///tmp/build/80754af9/seaborn_1608578541026/work
SecretStorage @ file:///tmp/build/80754af9/secretstorage_1614022784285/work
Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
simplegeneric==0.8.1
singledispatch @ file:///tmp/build/80754af9/singledispatch_1614366001199/work
sip==4.19.13
six @ file:///tmp/build/80754af9/six_1605205327372/work
sniffio @ file:///tmp/build/80754af9/sniffio_1614030475067/work
snowballstemmer @ file:///tmp/build/80754af9/snowballstemmer_1611258885636/work
sortedcollections @ file:///tmp/build/80754af9/sortedcollections_1611172717284/work
sortedcontainers @ file:///tmp/build/80754af9/sortedcontainers_1606865132123/work
soupsieve @ file:///tmp/build/80754af9/soupsieve_1616183228191/work
Sphinx @ file:///tmp/build/80754af9/sphinx_1616268783226/work
sphinxcontrib-applehelp @ file:///home/ktietz/src/ci/sphinxcontrib-applehelp_1611920841464/work
sphinxcontrib-devhelp @ file:///home/ktietz/src/ci/sphinxcontrib-devhelp_1611920923094/work
sphinxcontrib-htmlhelp @ file:///home/ktietz/src/ci/sphinxcontrib-htmlhelp_1611920974801/work
sphinxcontrib-jsmath @ file:///home/ktietz/src/ci/sphinxcontrib-jsmath_1611920942228/work
sphinxcontrib-qthelp @ file:///home/ktietz/src/ci/sphinxcontrib-qthelp_1611921055322/work
sphinxcontrib-serializinghtml @ file:///home/ktietz/src/ci/sphinxcontrib-serializinghtml_1611920755253/work
sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work
spyder @ file:///tmp/build/80754af9/spyder_1618327905127/work
spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1617396566288/work
SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1618089170652/work
statsmodels @ file:///tmp/build/80754af9/statsmodels_1614023746358/work
sympy @ file:///tmp/build/80754af9/sympy_1618252284338/work
tables==3.6.1
tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work
tensorboardX==2.2
terminado==0.9.4
testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work
text-unidecode==1.3
textdistance @ file:///tmp/build/80754af9/textdistance_1612461398012/work
threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl
three-merge @ file:///tmp/build/80754af9/three-merge_1607553261110/work
tinycss @ file:///tmp/build/80754af9/tinycss_1617713798712/work
toml @ file:///tmp/build/80754af9/toml_1616166611790/work
toolz @ file:///home/linux1/recipes/ci/toolz_1610987900194/work
torch==1.7.0
torch-cluster==1.5.9
torch-geometric==1.7.0
torch-scatter==2.0.6
torch-sparse==0.6.9
torch-spline-conv==1.2.1
torchaudio==0.7.0a0+ac17b64
torchvision==0.8.0
tornado @ file:///tmp/build/80754af9/tornado_1606942300299/work
tqdm @ file:///tmp/build/80754af9/tqdm_1615925068909/work
traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
tsBNgen==1.0.0
typed-ast @ file:///tmp/build/80754af9/typed-ast_1610484547928/work
typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
ujson @ file:///tmp/build/80754af9/ujson_1611259522456/work
unicodecsv==0.14.1
Unidecode @ file:///tmp/build/80754af9/unidecode_1614712377438/work
urllib3 @ file:///tmp/build/80754af9/urllib3_1615837158687/work
watchdog @ file:///tmp/build/80754af9/watchdog_1612471027849/work
wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
webencodings==0.5.1
Werkzeug @ file:///home/ktietz/src/ci/werkzeug_1611932622770/work
whichcraft @ file:///tmp/build/80754af9/whichcraft_1617751293875/work
widgetsnbextension==3.5.1
wrapt==1.12.1
wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1617224664226/work
xlrd @ file:///tmp/build/80754af9/xlrd_1608072521494/work
XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1617224712951/work
xlwt==1.3.0
yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work
zict==2.0.0
zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work
zope.event==4.5.0
zope.interface @ file:///tmp/build/80754af9/zope.interface_1616357211867/work
Download .txt
gitextract_qan2yim0/

├── .idea/
│   ├── .gitignore
│   ├── GNN_biomarker_MEDIA.iml
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── inspectionProfiles/
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── webServers.xml
├── 01-fetch_data.py
├── 02-process_data.py
├── 03-main.py
├── README.md
├── data/
│   └── subject_ID.txt
├── imports/
│   ├── ABIDEDataset.py
│   ├── __inits__.py
│   ├── gdc.py
│   ├── preprocess_data.py
│   ├── read_abide_stats_parall.py
│   └── utils.py
├── net/
│   ├── braingnn.py
│   ├── braingraphconv.py
│   ├── brainmsgpassing.py
│   └── inits.py
└── requirements.txt
Download .txt
SYMBOL INDEX (68 symbols across 12 files)

FILE: 01-fetch_data.py
  function str2bool (line 36) | def str2bool(v):
  function main (line 47) | def main():

FILE: 02-process_data.py
  function str2bool (line 31) | def str2bool(v):
  function main (line 42) | def main():

FILE: 03-main.py
  function topk_loss (line 96) | def topk_loss(s,ratio):
  function consist_loss (line 104) | def consist_loss(s):
  function train (line 116) | def train(epoch):
  function test_acc (line 163) | def test_acc(loader):
  function test_loss (line 174) | def test_loss(loader,epoch):

FILE: imports/ABIDEDataset.py
  class ABIDEDataset (line 10) | class ABIDEDataset(InMemoryDataset):
    method __init__ (line 11) | def __init__(self, root, name, transform=None, pre_transform=None):
    method raw_file_names (line 18) | def raw_file_names(self):
    method processed_file_names (line 24) | def processed_file_names(self):
    method download (line 27) | def download(self):
    method process (line 31) | def process(self):
    method __repr__ (line 47) | def __repr__(self):

FILE: imports/gdc.py
  function jit (line 10) | def jit():
  class GDC (line 20) | class GDC(object):
    method __init__ (line 70) | def __init__(self, self_loop_weight=1, normalization_in='sym',
    method __call__ (line 86) | def __call__(self, data):
    method transition_matrix (line 127) | def transition_matrix(self, edge_index, edge_weight, num_nodes,
    method diffusion_matrix_exact (line 173) | def diffusion_matrix_exact(self, edge_index, edge_weight, num_nodes,
    method diffusion_matrix_approx (line 232) | def diffusion_matrix_approx(self, edge_index, edge_weight, num_nodes,
    method sparsify_dense (line 305) | def sparsify_dense(self, matrix, method, **kwargs):
    method sparsify_sparse (line 363) | def sparsify_sparse(self, edge_index, edge_weight, num_nodes, method,
    method __expm__ (line 397) | def __expm__(self, matrix, symmetric):
    method __calculate_eps__ (line 412) | def __calculate_eps__(self, matrix, num_nodes, avg_degree):
    method __neighbors_to_graph__ (line 425) | def __neighbors_to_graph__(self, neighbors, neighbor_weights,
    method __calc_ppr__ (line 456) | def __calc_ppr__(indptr, indices, out_degree, alpha, eps):
    method __repr__ (line 503) | def __repr__(self):

FILE: imports/preprocess_data.py
  function fetch_filenames (line 44) | def fetch_filenames(subject_IDs, file_type, atlas):
  function get_timeseries (line 74) | def get_timeseries(subject_list, atlas_name, silence=False):
  function subject_connectivity (line 95) | def subject_connectivity(timeseries, subjects, atlas_name, kind, iter_no...
  function get_ids (line 144) | def get_ids(num_subjects=None):
  function get_subject_score (line 159) | def get_subject_score(subject_list, score):
  function preprocess_phenotypes (line 188) | def preprocess_phenotypes(pheno_ft, params):
  function phenotype_ft_vector (line 201) | def phenotype_ft_vector(pheno_ft, num_subjects, params):
  function get_networks (line 234) | def get_networks(subject_list, kind, iter_no='', seed=1234, n_subjects='...

FILE: imports/read_abide_stats_parall.py
  function split (line 26) | def split(data, batch):
  function cat (line 53) | def cat(seq):
  class NoDaemonProcess (line 58) | class NoDaemonProcess(multiprocessing.Process):
    method daemon (line 60) | def daemon(self):
    method daemon (line 64) | def daemon(self, value):
  class NoDaemonContext (line 68) | class NoDaemonContext(type(multiprocessing.get_context())):
  function read_data (line 72) | def read_data(data_dir):
  function read_sigle_data (line 128) | def read_sigle_data(data_dir,filename,use_gdc =False):

FILE: imports/utils.py
  function train_val_test_split (line 10) | def train_val_test_split(kfold = 5, fold = 0):

FILE: net/braingnn.py
  class Network (line 14) | class Network(torch.nn.Module):
    method __init__ (line 15) | def __init__(self, indim, ratio, nclass, k=8, R=200):
    method forward (line 52) | def forward(self, x, edge_index, batch, edge_attr, pos):
    method augment_adj (line 77) | def augment_adj(self, edge_index, edge_weight, num_nodes):

FILE: net/braingraphconv.py
  class MyNNConv (line 12) | class MyNNConv(MyMessagePassing):
    method __init__ (line 13) | def __init__(self, in_channels, out_channels, nn, normalize=False, bia...
    method reset_parameters (line 30) | def reset_parameters(self):
    method forward (line 34) | def forward(self, x, edge_index, edge_weight=None, pseudo= None, size=...
    method message (line 58) | def message(self, edge_index_i, size_i, x_j, edge_weight, ptr: OptTens...
    method update (line 62) | def update(self, aggr_out):
    method __repr__ (line 69) | def __repr__(self):

FILE: net/brainmsgpassing.py
  class MyMessagePassing (line 18) | class MyMessagePassing(torch.nn.Module):
    method __init__ (line 40) | def __init__(self, aggr='add', flow='source_to_target', node_dim=0):
    method propagate (line 61) | def propagate(self, edge_index, size=None, **kwargs):
    method message (line 133) | def message(self, x_j):  # pragma: no cover
    method update (line 146) | def update(self, aggr_out):  # pragma: no cover

FILE: net/inits.py
  function uniform (line 4) | def uniform(size, tensor):
  function kaiming_uniform (line 10) | def kaiming_uniform(tensor, fan, a):
  function glorot (line 16) | def glorot(tensor):
  function zeros (line 22) | def zeros(tensor):
  function ones (line 27) | def ones(tensor):
Condensed preview — 24 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (121K chars).
[
  {
    "path": ".idea/.gitignore",
    "chars": 176,
    "preview": "# Default ignored files\n/shelf/\n/workspace.xml\n# Datasource local storage ignored files\n/dataSources/\n/dataSources.local"
  },
  {
    "path": ".idea/GNN_biomarker_MEDIA.iml",
    "chars": 421,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n  <component name=\"NewModuleRootManager"
  },
  {
    "path": ".idea/deployment.xml",
    "chars": 870,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"PublishConfigData\" serverName=\"li-gan\">\n"
  },
  {
    "path": ".idea/encodings.xml",
    "chars": 135,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"Encoding\" addBOMForNewFiles=\"with NO BOM"
  },
  {
    "path": ".idea/inspectionProfiles/Project_Default.xml",
    "chars": 16512,
    "preview": "<component name=\"InspectionProjectProfileManager\">\n  <profile version=\"1.0\">\n    <option name=\"myName\" value=\"Project De"
  },
  {
    "path": ".idea/misc.xml",
    "chars": 396,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectRootManager\" version=\"2\" project-"
  },
  {
    "path": ".idea/modules.xml",
    "chars": 290,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectModuleManager\">\n    <modules>\n   "
  },
  {
    "path": ".idea/webServers.xml",
    "chars": 603,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"WebServers\">\n    <option name=\"servers\">"
  },
  {
    "path": "01-fetch_data.py",
    "chars": 3951,
    "preview": "# Copyright (c) 2019 Mwiza Kunda\n# Copyright (C) 2017 Sarah Parisot <s.parisot@imperial.ac.uk>, , Sofia Ira Ktena <ira.k"
  },
  {
    "path": "02-process_data.py",
    "chars": 3562,
    "preview": "# Copyright (c) 2019 Mwiza Kunda\n#\n# This program is free software: you can redistribute it and/or modify\n# it under the"
  },
  {
    "path": "03-main.py",
    "chars": 10678,
    "preview": "import os\nimport numpy as np\nimport argparse\nimport time\nimport copy\n\nimport torch\nimport torch.nn.functional as F\nfrom "
  },
  {
    "path": "README.md",
    "chars": 1393,
    "preview": "# Graph Neural Network for Brain Network Analysis\n A preliminary implementation of BrainGNN. The example presented here "
  },
  {
    "path": "data/subject_ID.txt",
    "chars": 6210,
    "preview": "50128\n51203\n50325\n50117\n50573\n50741\n50779\n51009\n50746\n50574\n50110\n50322\n51036\n51204\n50119\n50126\n50314\n51490\n50784\n51464\n"
  },
  {
    "path": "imports/ABIDEDataset.py",
    "chars": 1671,
    "preview": "import torch\nfrom torch_geometric.data import InMemoryDataset,Data\nfrom os.path import join, isfile\nfrom os import listd"
  },
  {
    "path": "imports/__inits__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "imports/gdc.py",
    "chars": 23308,
    "preview": "import torch\nimport numba\nimport numpy as np\nfrom scipy.linalg import expm\nfrom torch_geometric.utils import add_self_lo"
  },
  {
    "path": "imports/preprocess_data.py",
    "chars": 10249,
    "preview": "# Copyright (c) 2019 Mwiza Kunda\n# Copyright (C) 2017 Sarah Parisot <s.parisot@imperial.ac.uk>, Sofia Ira Ktena <ira.kte"
  },
  {
    "path": "imports/read_abide_stats_parall.py",
    "chars": 5517,
    "preview": "'''\nAuthor: Xiaoxiao Li\nDate: 2019/02/24\n'''\n\nimport os.path as osp\nfrom os import listdir\nimport os\nimport glob\nimport "
  },
  {
    "path": "imports/utils.py",
    "chars": 899,
    "preview": "from scipy import stats\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nfrom scipy.io import loadmat\nfro"
  },
  {
    "path": "net/braingnn.py",
    "chars": 3771,
    "preview": "import torch\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom torch_geometric.nn import TopKPooling\nfrom torch"
  },
  {
    "path": "net/braingraphconv.py",
    "chars": 2809,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import Parameter\nfrom net.brainmsgpassing import MyMessagePas"
  },
  {
    "path": "net/brainmsgpassing.py",
    "chars": 6467,
    "preview": "import sys\nimport inspect\n\nimport torch\n# from torch_geometric.utils import scatter_\nfrom torch_scatter import scatter,s"
  },
  {
    "path": "net/inits.py",
    "chars": 617,
    "preview": "import math\n\n\ndef uniform(size, tensor):\n    bound = 1.0 / math.sqrt(size)\n    if tensor is not None:\n        tensor.dat"
  },
  {
    "path": "requirements.txt",
    "chars": 14531,
    "preview": "alabaster @ file:///home/ktietz/src/ci/alabaster_1611921544520/work\nanaconda-client==1.7.2\nanaconda-project @ file:///tm"
  }
]

About this extraction

This page contains the full source code of the xxlya/BrainGNN_Pytorch GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 24 files (112.3 KB, approximately 31.7k tokens) and a symbol index of 68 extracted functions, classes, methods, constants, and types. It can be used with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input; the full output can be copied to the clipboard or downloaded as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!